diff --git a/arkindex/training/api.py b/arkindex/training/api.py index 55acaf6f0d3311448a155766d522744653699a99..353dc6e9286760c278d8895c12e6f5d9bae11135 100644 --- a/arkindex/training/api.py +++ b/arkindex/training/api.py @@ -307,10 +307,15 @@ class ValidateModelVersion(TrainingModelMixin, GenericAPIView): # Set the current model version as erroneous and return the available one instance.state = ModelVersionState.Error instance.save(update_fields=["state"]) + # Set context + context = { + **self.get_serializer_context(), + "is_contributor": True + } return Response( ModelVersionSerializer( existing_model_version, - context={"is_contributor": True, "model": instance}, + context=context, ).data, status=status.HTTP_409_CONFLICT, ) diff --git a/arkindex/training/serializers.py b/arkindex/training/serializers.py index 4cff4c8f0f2937a4cb4879f6d10902394c6375de..ab1e300b586aceb1a7a6761b169349cac2524e12 100644 --- a/arkindex/training/serializers.py +++ b/arkindex/training/serializers.py @@ -213,7 +213,7 @@ class ModelVersionSerializer(serializers.ModelSerializer): else: model = self.context.get("model") if model: - qs = ModelVersion.objects.filter(model_id=model.id) + qs = ModelVersion.objects.filter(model__in=Model.objects.readable(self.context["request"].user)) if getattr(self.instance, "id", None): qs = qs.exclude(id=self.instance.id) self.fields["parent"].queryset = qs diff --git a/arkindex/training/tests/test_model_api.py b/arkindex/training/tests/test_model_api.py index 7a264b2f3f26fa529d8000960acfe5e8b73f84ee..303a25e5201a99049219c6c8a69b9227abb13af4 100644 --- a/arkindex/training/tests/test_model_api.py +++ b/arkindex/training/tests/test_model_api.py @@ -266,6 +266,82 @@ class TestModelAPI(FixtureAPITestCase): self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertDictEqual(response.json(), {"non_field_errors": ["A version for this model with this tag already exists."]}) + @patch("arkindex.project.aws.s3.meta.client.generate_presigned_url") + def test_create_model_version_any_parent_model_version(self, s3_presigned_url_mock): + """ + Any readable model version can be set as parent of a model version, not just versions of the same model + """ + self.client.force_login(self.user1) + s3_presigned_url_mock.return_value = "http://s3/upload_put_url" + fake_now = timezone.now() + # To mock the creation date + with patch("django.utils.timezone.now") as mock_now: + mock_now.return_value = fake_now + with self.assertNumQueries(6): + response = self.client.post( + reverse("api:model-versions", kwargs={"pk": str(self.model1.id)}), + { + "tag": "TAG", + "description": "description", + "configuration": {"hello": "this is me"}, + # self.model_version3 belongs to self.model2 + "parent": str(self.model_version3.id) + }, + format="json", + ) + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + data = response.json() + self.assertIn("id", data) + df = ModelVersion.objects.get(id=data["id"]) + + self.assertDictEqual( + data, + { + "id": str(df.id), + "model_id": str(self.model1.id), + "parent": str(self.model_version3.id), + "description": "description", + "state": ModelVersionState.Created.value, + "configuration": {"hello": "this is me"}, + "tag": "TAG", + "size": None, + "hash": None, + "created": fake_now.isoformat().replace("+00:00", "Z"), + "s3_url": None, + "s3_put_url": s3_presigned_url_mock.return_value + } + ) + + @patch("arkindex.users.managers.BaseACLManager.filter_rights") + def test_create_model_version_readable_parent_version(self, filter_rights_mock): + """ + Only model versions the user has read access to can be set as parents of a model version + """ + filter_rights_mock.return_value = Model.objects.filter(id=self.model1.id) + self.client.force_login(self.user1) + fake_now = timezone.now() + # To mock the creation date + with patch("django.utils.timezone.now") as mock_now: + mock_now.return_value = fake_now + with self.assertNumQueries(4): + response = self.client.post( + reverse("api:model-versions", kwargs={"pk": str(self.model1.id)}), + { + "tag": "TAG", + "description": "description", + "configuration": {"hello": "this is me"}, + # self.model_version3 belongs to self.model2 + "parent": str(self.model_version3.id) + }, + format="json", + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual(filter_rights_mock.call_count, 1) + self.assertEqual(filter_rights_mock.call_args, call(self.user1, Model, Role.Guest.value)) + self.assertDictEqual(response.json(), { + "parent": [f'Invalid pk "{str(self.model_version3.id)}" - object does not exist.'] + }) + def test_retrieve_model_requires_login(self): with self.assertNumQueries(0): response = self.client.get(reverse("api:model-retrieve", kwargs={"pk": str(self.model2.id)})) @@ -934,8 +1010,63 @@ class TestModelAPI(FixtureAPITestCase): "size": 8, }) + @patch("arkindex.project.aws.s3.meta.client.generate_presigned_url") + def test_partial_update_any_parent_model_version(self, s3_presigned_url): + """ + A model version can have any model version as a parent, not just a version from the same model + """ + s3_presigned_url.return_value = "http://s3/get_url" + self.client.force_login(self.user2) + + with self.assertNumQueries(6): + response = self.client.patch( + reverse("api:model-version-retrieve", kwargs={"pk": str(self.model_version3.id)}), + # self.model_version3 is a version of self.model2, while self.model_version1 belongs to self.model1 + {"parent": str(self.model_version1.id)}, + format="json" + ) + self.assertEqual(response.status_code, status.HTTP_200_OK) + + self.assertDictEqual(response.json(), { + "id": str(self.model_version3.id), + "model_id": str(self.model2.id), + "created": self.model_version3.created.isoformat().replace("+00:00", "Z"), + "s3_url": "http://s3/get_url", + "s3_put_url": None, + "tag": "tagged", + "description": "", + "configuration": {}, + "parent": str(self.model_version1.id), + "state": ModelVersionState.Available.value, + "hash": "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbba", + "size": 8, + }) + self.model_version3.refresh_from_db() - self.assertIsNone(self.model_version3.parent_id) + self.assertEqual(self.model_version3.parent_id, self.model_version1.id) + + @patch("arkindex.users.managers.BaseACLManager.filter_rights") + def test_partial_update_readable_parent_model_version(self, filter_rights_mock): + """ + Only model versions the user has read access to can be set as parents of a model version + """ + filter_rights_mock.return_value = Model.objects.filter(id__in=[self.model2.id, self.model3.id]) + self.client.force_login(self.user2) + + with self.assertNumQueries(4): + response = self.client.patch( + reverse("api:model-version-retrieve", kwargs={"pk": str(self.model_version3.id)}), + # self.model_version3 is a version of self.model2, while self.model_version1 belongs to self.model1 + {"parent": str(self.model_version1.id)}, + format="json" + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + self.assertEqual(filter_rights_mock.call_count, 1) + self.assertEqual(filter_rights_mock.call_args, call(self.user2, Model, Role.Guest.value)) + self.assertDictEqual(response.json(), { + "parent": [f'Invalid pk "{str(self.model_version1.id)}" - object does not exist.'] + }) @patch("arkindex.training.api.get_max_level", return_value=None) def test_update_model_version_requires_contributor(self, get_max_level_mock):