From bc0c2a69fbb751c28ea1ebab5ce3df2c09a407cd Mon Sep 17 00:00:00 2001 From: ml bonhomme <bonhomme@teklia.com> Date: Tue, 5 Mar 2024 15:30:00 +0000 Subject: [PATCH] Allow setting any readable model version as parent of a model version, not just versions from the same model --- arkindex/training/api.py | 7 +- arkindex/training/serializers.py | 2 +- arkindex/training/tests/test_model_api.py | 133 +++++++++++++++++++++- 3 files changed, 139 insertions(+), 3 deletions(-) diff --git a/arkindex/training/api.py b/arkindex/training/api.py index 55acaf6f0d..353dc6e928 100644 --- a/arkindex/training/api.py +++ b/arkindex/training/api.py @@ -307,10 +307,15 @@ class ValidateModelVersion(TrainingModelMixin, GenericAPIView): # Set the current model version as erroneous and return the available one instance.state = ModelVersionState.Error instance.save(update_fields=["state"]) + # Set context + context = { + **self.get_serializer_context(), + "is_contributor": True + } return Response( ModelVersionSerializer( existing_model_version, - context={"is_contributor": True, "model": instance}, + context=context, ).data, status=status.HTTP_409_CONFLICT, ) diff --git a/arkindex/training/serializers.py b/arkindex/training/serializers.py index 4cff4c8f0f..ab1e300b58 100644 --- a/arkindex/training/serializers.py +++ b/arkindex/training/serializers.py @@ -213,7 +213,7 @@ class ModelVersionSerializer(serializers.ModelSerializer): else: model = self.context.get("model") if model: - qs = ModelVersion.objects.filter(model_id=model.id) + qs = ModelVersion.objects.filter(model__in=Model.objects.readable(self.context["request"].user)) if getattr(self.instance, "id", None): qs = qs.exclude(id=self.instance.id) self.fields["parent"].queryset = qs diff --git a/arkindex/training/tests/test_model_api.py b/arkindex/training/tests/test_model_api.py index 7a264b2f3f..303a25e520 100644 --- a/arkindex/training/tests/test_model_api.py +++ b/arkindex/training/tests/test_model_api.py @@ -266,6 +266,82 @@ class TestModelAPI(FixtureAPITestCase): self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertDictEqual(response.json(), {"non_field_errors": ["A version for this model with this tag already exists."]}) + @patch("arkindex.project.aws.s3.meta.client.generate_presigned_url") + def test_create_model_version_any_parent_model_version(self, s3_presigned_url_mock): + """ + Any readable model version can be set as parent of a model version, not just versions of the same model + """ + self.client.force_login(self.user1) + s3_presigned_url_mock.return_value = "http://s3/upload_put_url" + fake_now = timezone.now() + # To mock the creation date + with patch("django.utils.timezone.now") as mock_now: + mock_now.return_value = fake_now + with self.assertNumQueries(6): + response = self.client.post( + reverse("api:model-versions", kwargs={"pk": str(self.model1.id)}), + { + "tag": "TAG", + "description": "description", + "configuration": {"hello": "this is me"}, + # self.model_version3 belongs to self.model2 + "parent": str(self.model_version3.id) + }, + format="json", + ) + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + data = response.json() + self.assertIn("id", data) + df = ModelVersion.objects.get(id=data["id"]) + + self.assertDictEqual( + data, + { + "id": str(df.id), + "model_id": str(self.model1.id), + "parent": str(self.model_version3.id), + "description": "description", + "state": ModelVersionState.Created.value, + "configuration": {"hello": "this is me"}, + "tag": "TAG", + "size": None, + "hash": None, + "created": fake_now.isoformat().replace("+00:00", "Z"), + "s3_url": None, + "s3_put_url": s3_presigned_url_mock.return_value + } + ) + + @patch("arkindex.users.managers.BaseACLManager.filter_rights") + def test_create_model_version_readable_parent_version(self, filter_rights_mock): + """ + Only model versions the user has read access to can be set as parents of a model version + """ + filter_rights_mock.return_value = Model.objects.filter(id=self.model1.id) + self.client.force_login(self.user1) + fake_now = timezone.now() + # To mock the creation date + with patch("django.utils.timezone.now") as mock_now: + mock_now.return_value = fake_now + with self.assertNumQueries(4): + response = self.client.post( + reverse("api:model-versions", kwargs={"pk": str(self.model1.id)}), + { + "tag": "TAG", + "description": "description", + "configuration": {"hello": "this is me"}, + # self.model_version3 belongs to self.model2 + "parent": str(self.model_version3.id) + }, + format="json", + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual(filter_rights_mock.call_count, 1) + self.assertEqual(filter_rights_mock.call_args, call(self.user1, Model, Role.Guest.value)) + self.assertDictEqual(response.json(), { + "parent": [f'Invalid pk "{str(self.model_version3.id)}" - object does not exist.'] + }) + def test_retrieve_model_requires_login(self): with self.assertNumQueries(0): response = self.client.get(reverse("api:model-retrieve", kwargs={"pk": str(self.model2.id)})) @@ -934,8 +1010,63 @@ class TestModelAPI(FixtureAPITestCase): "size": 8, }) + @patch("arkindex.project.aws.s3.meta.client.generate_presigned_url") + def test_partial_update_any_parent_model_version(self, s3_presigned_url): + """ + A model version can have any model version as a parent, not just a version from the same model + """ + s3_presigned_url.return_value = "http://s3/get_url" + self.client.force_login(self.user2) + + with self.assertNumQueries(6): + response = self.client.patch( + reverse("api:model-version-retrieve", kwargs={"pk": str(self.model_version3.id)}), + # self.model_version3 is a version of self.model2, while self.model_version1 belongs to self.model1 + {"parent": str(self.model_version1.id)}, + format="json" + ) + self.assertEqual(response.status_code, status.HTTP_200_OK) + + self.assertDictEqual(response.json(), { + "id": str(self.model_version3.id), + "model_id": str(self.model2.id), + "created": self.model_version3.created.isoformat().replace("+00:00", "Z"), + "s3_url": "http://s3/get_url", + "s3_put_url": None, + "tag": "tagged", + "description": "", + "configuration": {}, + "parent": str(self.model_version1.id), + "state": ModelVersionState.Available.value, + "hash": "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbba", + "size": 8, + }) + self.model_version3.refresh_from_db() - self.assertIsNone(self.model_version3.parent_id) + self.assertEqual(self.model_version3.parent_id, self.model_version1.id) + + @patch("arkindex.users.managers.BaseACLManager.filter_rights") + def test_partial_update_readable_parent_model_version(self, filter_rights_mock): + """ + Only model versions the user has read access to can be set as parents of a model version + """ + filter_rights_mock.return_value = Model.objects.filter(id__in=[self.model2.id, self.model3.id]) + self.client.force_login(self.user2) + + with self.assertNumQueries(4): + response = self.client.patch( + reverse("api:model-version-retrieve", kwargs={"pk": str(self.model_version3.id)}), + # self.model_version3 is a version of self.model2, while self.model_version1 belongs to self.model1 + {"parent": str(self.model_version1.id)}, + format="json" + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + self.assertEqual(filter_rights_mock.call_count, 1) + self.assertEqual(filter_rights_mock.call_args, call(self.user2, Model, Role.Guest.value)) + self.assertDictEqual(response.json(), { + "parent": [f'Invalid pk "{str(self.model_version1.id)}" - object does not exist.'] + }) @patch("arkindex.training.api.get_max_level", return_value=None) def test_update_model_version_requires_contributor(self, get_max_level_mock): -- GitLab