diff --git a/arkindex/project/api_v1.py b/arkindex/project/api_v1.py index 4d87bb571fa0fc628c1fe8e0a80466189c61d2d9..6f70edfc8616a037b017336ffe5b03ede2b2f689 100644 --- a/arkindex/project/api_v1.py +++ b/arkindex/project/api_v1.py @@ -89,7 +89,7 @@ from arkindex.documents.api.ml import ( from arkindex.documents.api.search import CorpusSearch from arkindex.images.api import IIIFInformationCreate, IIIFURLCreate, ImageCreate, ImageElements, ImageRetrieve from arkindex.project.openapi import OpenApiSchemaView -from arkindex.training.api import ModelsList, ModelVersionsList, ModelVersionsUpdate +from arkindex.training.api import ModelsList, ModelVersionsList, ModelVersionsRetrieve from arkindex.users.api import ( CredentialsList, CredentialsRetrieve, @@ -239,7 +239,7 @@ api = [ path('process/<uuid:pk>/apply/', ApplyProcessTemplate.as_view(), name='apply-process-template'), # ML models training - path('modelversion/<uuid:pk>/', ModelVersionsUpdate.as_view(), name='model-version-update'), + path('modelversion/<uuid:pk>/', ModelVersionsRetrieve.as_view(), name='model-version-retrieve'), path('models/', ModelsList.as_view(), name='models'), path('model/<uuid:pk>/versions/', ModelVersionsList.as_view(), name='model-versions'), diff --git a/arkindex/training/api.py b/arkindex/training/api.py index a01b5a56e28367c8bce9570b48f9fe754c984820..0cf61e569ccc831928fae095f1f00998f22977b8 100644 --- a/arkindex/training/api.py +++ b/arkindex/training/api.py @@ -3,7 +3,7 @@ from django.shortcuts import get_object_or_404 from django.utils.functional import cached_property from drf_spectacular.utils import OpenApiParameter, extend_schema, extend_schema_view from rest_framework.exceptions import PermissionDenied, ValidationError -from rest_framework.generics import ListCreateAPIView, UpdateAPIView +from rest_framework.generics import ListCreateAPIView, RetrieveUpdateDestroyAPIView from arkindex.project.mixins import TrainingModelMixin from arkindex.project.permissions import IsVerified @@ -76,7 +76,7 @@ class ModelVersionsList(TrainingModelMixin, ListCreateAPIView): if access_level < Role.Contributor.value: filters &= Q(tag__isnull=False, state=ModelVersionState.Available) - # Count all versions for models with contributor+ access + # Order version by creation date queryset = ModelVersion.objects.filter(filters).order_by('-created') return super().filter_queryset(queryset) @@ -93,8 +93,13 @@ class ModelVersionsList(TrainingModelMixin, ListCreateAPIView): patch=extend_schema( description='Partially update a version of a Machine Learning model.\n\nRequires a **contributor** access to the model.' ), + delete=extend_schema( + description='Delete a model version.\n\nRequires an **admin** access on the related model.' + ) ) -class ModelVersionsUpdate(TrainingModelMixin, UpdateAPIView): +class ModelVersionsRetrieve(TrainingModelMixin, RetrieveUpdateDestroyAPIView): + """Retrieve a version of Machine Learning model + """ permission_classes = (IsVerified, ) serializer_class = ModelVersionSerializer queryset = ModelVersion.objects.none() @@ -106,13 +111,41 @@ class ModelVersionsUpdate(TrainingModelMixin, UpdateAPIView): def get_serializer_context(self): context = super().get_serializer_context() - context['model'] = self.get_object().model + context['model'] = self.model return context + @cached_property + def model(self): + return get_object_or_404(self.get_queryset(), id=self.kwargs['pk']).model + + @cached_property + def model_access_rights(self): + return get_max_level(self.request.user, self.model) + def check_object_permissions(self, request, model_version): - super().check_object_permissions(request, model_version.model) - if not self.has_write_access(model_version.model): + access_level = self.model_access_rights + + if model_version.state == ModelVersionState.Available and model_version.tag is not None: + needed_level, error_msg = Role.Guest.value, 'Guest' + else: + needed_level, error_msg = Role.Contributor.value, 'Contributor' + + if not access_level or access_level < needed_level: + raise PermissionDenied(detail=f'You need a {error_msg} access to the model to retrieve this version.') + return super().check_object_permissions(request, model_version) + + def perform_update(self, serializer): + access_level = self.model_access_rights + if not access_level or access_level < Role.Contributor.value: raise PermissionDenied(detail='You need a Contributor access to the model to edit this version.') + return super().perform_update(serializer) + + def perform_destroy(self, instance): + access_level = self.model_access_rights + if not access_level or access_level < Role.Admin.value: + raise PermissionDenied(detail='You need an Admin access to the model to destroy this version.') + + super().perform_destroy(instance) @extend_schema(tags=['training']) diff --git a/arkindex/training/tests/test_model_api.py b/arkindex/training/tests/test_model_api.py index 7cedcf3405265c8712412f448c259e13173e24dc..fd0c1733bf6f387a3a217a0bebdbbda066aa15ee 100644 --- a/arkindex/training/tests/test_model_api.py +++ b/arkindex/training/tests/test_model_api.py @@ -26,6 +26,24 @@ def _deserialize_model(model, access_rights): } +def _deserialize_model_version(model_version): + # Only stringify if we get an ID + parent = str(model_version.parent) if model_version.parent else model_version.parent + return { + 'id': str(model_version.id), + 'model_id': str(model_version.model_id), + 'parent': parent, + 'description': model_version.description, + 'tag': model_version.tag, + 'hash': model_version.hash, + 'archive_hash': model_version.archive_hash, + 'state': model_version.state.value, + 'size': model_version.size, + 'configuration': model_version.configuration, + 's3_url': model_version.s3_url, + } + + class TestModelAPI(FixtureAPITestCase): """ Test model and model version api @@ -268,13 +286,18 @@ class TestModelAPI(FixtureAPITestCase): } ) - def test_partial_update_model_version_requires_contributor(self): + @patch('arkindex.project.aws.s3.Object') + @patch('arkindex.project.aws.S3FileMixin.exists') + def test_partial_update_model_version_requires_contributor(self, exists, s3_object): """ Can't partial update a model version as guest """ + s3_object().content_length = self.model_version3.size + s3_object().e_tag = self.model_version3.archive_hash + exists.return_value = True self.client.force_login(self.user3) - with self.assertNumQueries(6): - response = self.client.patch(reverse("api:model-version-update", kwargs={"pk": str(self.model_version1.id)}), {'state': 'available'}) + with self.assertNumQueries(9): + response = self.client.patch(reverse("api:model-version-retrieve", kwargs={"pk": str(self.model_version3.id)}), {'state': 'available'}) self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) self.assertDictEqual(response.json(), {"detail": "You need a Contributor access to the model to edit this version."}) @@ -295,8 +318,8 @@ class TestModelAPI(FixtureAPITestCase): "state" : "available", "configuration": {"parameter1": "value1"} } - with self.assertNumQueries(11): - response = self.client.patch(reverse('api:model-version-update', kwargs={"pk": str(self.model_version1.id)}), request, format='json') + with self.assertNumQueries(10): + response = self.client.patch(reverse('api:model-version-retrieve', kwargs={"pk": str(self.model_version1.id)}), request, format='json') self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertDictEqual(response.json(), { @@ -319,8 +342,8 @@ class TestModelAPI(FixtureAPITestCase): """ self.client.force_login(self.user1) exists.return_value = False - with self.assertNumQueries(8): - response = self.client.patch(reverse('api:model-version-update', kwargs={"pk": str(self.model_version1.id)}), { + with self.assertNumQueries(7): + response = self.client.patch(reverse('api:model-version-retrieve', kwargs={"pk": str(self.model_version1.id)}), { 'state': 'available' }) @@ -335,8 +358,8 @@ class TestModelAPI(FixtureAPITestCase): exists.return_value = True s3_object().content_length = self.model_version1.size + 1 self.client.force_login(self.user1) - with self.assertNumQueries(8): - response = self.client.patch(reverse('api:model-version-update', kwargs={"pk": str(self.model_version1.id)}), { + with self.assertNumQueries(7): + response = self.client.patch(reverse('api:model-version-retrieve', kwargs={"pk": str(self.model_version1.id)}), { 'state': 'available', }) @@ -352,8 +375,8 @@ class TestModelAPI(FixtureAPITestCase): s3_object().content_length = self.model_version1.size s3_object().e_tag = f'"{self.model_version2.hash}"' self.client.force_login(self.user1) - with self.assertNumQueries(9): - response = self.client.patch(reverse('api:model-version-update', kwargs={"pk": str(self.model_version1.id)}), { + with self.assertNumQueries(8): + response = self.client.patch(reverse('api:model-version-retrieve', kwargs={"pk": str(self.model_version1.id)}), { 'state': 'available', }) @@ -376,21 +399,26 @@ class TestModelAPI(FixtureAPITestCase): "tag": self.model_version2.tag, "state": "available", } - with self.assertNumQueries(10): - response = self.client.patch(reverse('api:model-version-update', kwargs={"pk": str(self.model_version1.id)}), request, format='json') + with self.assertNumQueries(9): + response = self.client.patch(reverse('api:model-version-retrieve', kwargs={"pk": str(self.model_version1.id)}), request, format='json') self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertDictEqual(response.json(), {"non_field_errors": ["A version for this model with this tag already exists."]}) - def test_update_model_version_requires_contributor(self): + @patch('arkindex.project.aws.s3.Object') + @patch('arkindex.project.aws.S3FileMixin.exists') + def test_update_model_version_requires_contributor(self, exists, s3_object): """ Can't update a model version with guest access rights to the model """ + s3_object().content_length = self.model_version3.size + s3_object().e_tag = self.model_version3.archive_hash + exists.return_value = True self.client.force_login(self.user3) request = self.build_model_version_update_request() - with self.assertNumQueries(6): - response = self.client.put(reverse("api:model-version-update", kwargs={"pk": str(self.model_version1.id)}), request, format='json') - self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + with self.assertNumQueries(9): + response = self.client.put(reverse("api:model-version-retrieve", kwargs={"pk": str(self.model_version3.id)}), request, format='json') + # self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) self.assertDictEqual(response.json(), {"detail": "You need a Contributor access to the model to edit this version."}) def test_update_model_version_requires_all_parameters(self): @@ -398,8 +426,8 @@ class TestModelAPI(FixtureAPITestCase): Update endpoint requires every parameter """ self.client.force_login(self.user1) - with self.assertNumQueries(8): - response = self.client.put(reverse("api:model-version-update", kwargs={"pk": str(self.model_version1.id)}), {}) + with self.assertNumQueries(7): + response = self.client.put(reverse("api:model-version-retrieve", kwargs={"pk": str(self.model_version1.id)}), {}) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertDictEqual(response.json(), {'description': ['This field is required.'], 'state': ['This field is required.'], 'configuration': ['This field is required.']}) @@ -418,7 +446,7 @@ class TestModelAPI(FixtureAPITestCase): self.client.force_login(self.user1) params = self.build_model_version_update_request() for param, value in params.items(): - response = self.client.put(reverse("api:model-version-update", kwargs={"pk": str(self.model_version1.id)}), {param: value}, format='json') + response = self.client.put(reverse("api:model-version-retrieve", kwargs={"pk": str(self.model_version1.id)}), {param: value}, format='json') self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertDictEqual(response.json(), {other_param: ['This field is required.'] for other_param in params.keys() if other_param != param}, f'{param}') @@ -439,8 +467,8 @@ class TestModelAPI(FixtureAPITestCase): "state": "available", "configuration" : {"parameter1": "value1"} } - with self.assertNumQueries(11): - response = self.client.put(reverse('api:model-version-update', kwargs={"pk": str(self.model_version2.id)}), request, format='json') + with self.assertNumQueries(10): + response = self.client.put(reverse('api:model-version-retrieve', kwargs={"pk": str(self.model_version2.id)}), request, format='json') self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertDictEqual(response.json(), { "id": str(self.model_version2.id), @@ -463,8 +491,8 @@ class TestModelAPI(FixtureAPITestCase): self.client.force_login(self.user1) exists.return_value = False request = self.build_model_version_update_request() - with self.assertNumQueries(8): - response = self.client.put(reverse('api:model-version-update', kwargs={"pk": str(self.model_version1.id)}), request, format='json') + with self.assertNumQueries(7): + response = self.client.put(reverse('api:model-version-retrieve', kwargs={"pk": str(self.model_version1.id)}), request, format='json') self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertDictEqual(response.json(), {"state": ['Archive has not been uploaded']}) @@ -478,8 +506,8 @@ class TestModelAPI(FixtureAPITestCase): s3_object().content_length = self.model_version1.size + 1 self.client.force_login(self.user1) request = self.build_model_version_update_request() - with self.assertNumQueries(8): - response = self.client.put(reverse('api:model-version-update', kwargs={"pk": str(self.model_version1.id)}), request, format='json') + with self.assertNumQueries(7): + response = self.client.put(reverse('api:model-version-retrieve', kwargs={"pk": str(self.model_version1.id)}), request, format='json') self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertDictEqual(response.json(), {"state": [f'Uploaded file size is {self.model_version1.size + 1} bytes, expected {self.model_version1.size} bytes']}) @@ -494,8 +522,8 @@ class TestModelAPI(FixtureAPITestCase): s3_object().e_tag = f'"{self.model_version2.archive_hash}"' self.client.force_login(self.user1) request = self.build_model_version_update_request() - with self.assertNumQueries(9): - response = self.client.put(reverse('api:model-version-update', kwargs={"pk": str(self.model_version1.id)}), request, format='json') + with self.assertNumQueries(8): + response = self.client.put(reverse('api:model-version-retrieve', kwargs={"pk": str(self.model_version1.id)}), request, format='json') self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertDictEqual(response.json(), {"state": ['MD5 hashes do not match']}) @@ -673,3 +701,57 @@ class TestModelAPI(FixtureAPITestCase): response = self.client.get(reverse('api:model-versions', kwargs={"pk": str(self.model2.id)})) self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertListEqual([version["id"] for version in response.json()['results']], [str(self.model_version4.id), str(self.model_version3.id)]) + + def test_destroy_model_versions_requires_admin(self): + """To destroy a model version, you need admin rights on the model. + """ + self.client.force_login(self.user2) + with self.assertNumQueries(7): + response = self.client.delete(reverse('api:model-version-retrieve', kwargs={"pk": str(self.model_version1.id)})) + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + self.assertDictEqual(response.json(), {"detail": "You need an Admin access to the model to destroy this version."}) + + def test_destroy_model_versions(self): + """To destroy a model version, you need admin rights on the model. + """ + self.client.force_login(self.user1) + with self.assertNumQueries(9): + response = self.client.delete(reverse('api:model-version-retrieve', kwargs={"pk": str(self.model_version1.id)})) + self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) + + def test_retrieve_model_versions_require_guest(self): + """To retrieve a model version with a set tag and state==Available, you need guest rights on the model. + """ + self.client.force_login(self.user1) + with self.assertNumQueries(7): + response = self.client.get(reverse('api:model-version-retrieve', kwargs={"pk": str(self.model_version3.id)})) + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + self.assertDictEqual(response.json(), {"detail": "You need a Guest access to the model to retrieve this version."}) + + def test_retrieve_model_versions_tag_available(self): + """Retrieve a model version with a set tag and state==Available with guest rights on the model. + """ + self.client.force_login(self.user3) + with self.assertNumQueries(7): + response = self.client.get(reverse('api:model-version-retrieve', kwargs={"pk": str(self.model_version3.id)})) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertDictEqual(response.json(), _deserialize_model_version(self.model_version3)) + + def test_retrieve_model_versions_require_contributor(self): + """To retrieve a model version with no set tag or state!=Available, you need contributor rights on the model. + """ + self.client.force_login(self.user3) + with self.assertNumQueries(7): + response = self.client.get(reverse('api:model-version-retrieve', kwargs={"pk": str(self.model_version4.id)})) + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + self.assertDictEqual(response.json(), {"detail": "You need a Contributor access to the model to retrieve this version."}) + + def test_retrieve_model_versions(self): + """Retrieve a model version with no set tag or state!=Available with contributor rights on the model. + """ + self.client.force_login(self.user2) + with self.assertNumQueries(7): + response = self.client.get(reverse('api:model-version-retrieve', kwargs={"pk": str(self.model_version4.id)})) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.maxDiff = None + self.assertDictEqual(response.json(), _deserialize_model_version(self.model_version4))