Skip to content
Snippets Groups Projects
Commit bc0c2a69 authored by ml bonhomme's avatar ml bonhomme :bee: Committed by Erwan Rouchet
Browse files

Allow setting any readable model version as parent of a model version, not...

Allow setting any readable model version as parent of a model version, not just versions from the same model
parent acf2b82c
No related branches found
No related tags found
1 merge request!2255Allow setting any readable model version as parent of a model version, not just versions from the same model
......@@ -307,10 +307,15 @@ class ValidateModelVersion(TrainingModelMixin, GenericAPIView):
# Set the current model version as erroneous and return the available one
instance.state = ModelVersionState.Error
instance.save(update_fields=["state"])
# Set context
context = {
**self.get_serializer_context(),
"is_contributor": True
}
return Response(
ModelVersionSerializer(
existing_model_version,
context={"is_contributor": True, "model": instance},
context=context,
).data,
status=status.HTTP_409_CONFLICT,
)
......
......@@ -213,7 +213,7 @@ class ModelVersionSerializer(serializers.ModelSerializer):
else:
model = self.context.get("model")
if model:
qs = ModelVersion.objects.filter(model_id=model.id)
qs = ModelVersion.objects.filter(model__in=Model.objects.readable(self.context["request"].user))
if getattr(self.instance, "id", None):
qs = qs.exclude(id=self.instance.id)
self.fields["parent"].queryset = qs
......
......@@ -266,6 +266,82 @@ class TestModelAPI(FixtureAPITestCase):
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertDictEqual(response.json(), {"non_field_errors": ["A version for this model with this tag already exists."]})
@patch("arkindex.project.aws.s3.meta.client.generate_presigned_url")
def test_create_model_version_any_parent_model_version(self, s3_presigned_url_mock):
"""
Any readable model version can be set as parent of a model version, not just versions of the same model
"""
self.client.force_login(self.user1)
s3_presigned_url_mock.return_value = "http://s3/upload_put_url"
fake_now = timezone.now()
# To mock the creation date
with patch("django.utils.timezone.now") as mock_now:
mock_now.return_value = fake_now
with self.assertNumQueries(6):
response = self.client.post(
reverse("api:model-versions", kwargs={"pk": str(self.model1.id)}),
{
"tag": "TAG",
"description": "description",
"configuration": {"hello": "this is me"},
# self.model_version3 belongs to self.model2
"parent": str(self.model_version3.id)
},
format="json",
)
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
data = response.json()
self.assertIn("id", data)
df = ModelVersion.objects.get(id=data["id"])
self.assertDictEqual(
data,
{
"id": str(df.id),
"model_id": str(self.model1.id),
"parent": str(self.model_version3.id),
"description": "description",
"state": ModelVersionState.Created.value,
"configuration": {"hello": "this is me"},
"tag": "TAG",
"size": None,
"hash": None,
"created": fake_now.isoformat().replace("+00:00", "Z"),
"s3_url": None,
"s3_put_url": s3_presigned_url_mock.return_value
}
)
@patch("arkindex.users.managers.BaseACLManager.filter_rights")
def test_create_model_version_readable_parent_version(self, filter_rights_mock):
"""
Only model versions the user has read access to can be set as parents of a model version
"""
filter_rights_mock.return_value = Model.objects.filter(id=self.model1.id)
self.client.force_login(self.user1)
fake_now = timezone.now()
# To mock the creation date
with patch("django.utils.timezone.now") as mock_now:
mock_now.return_value = fake_now
with self.assertNumQueries(4):
response = self.client.post(
reverse("api:model-versions", kwargs={"pk": str(self.model1.id)}),
{
"tag": "TAG",
"description": "description",
"configuration": {"hello": "this is me"},
# self.model_version3 belongs to self.model2
"parent": str(self.model_version3.id)
},
format="json",
)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEqual(filter_rights_mock.call_count, 1)
self.assertEqual(filter_rights_mock.call_args, call(self.user1, Model, Role.Guest.value))
self.assertDictEqual(response.json(), {
"parent": [f'Invalid pk "{str(self.model_version3.id)}" - object does not exist.']
})
def test_retrieve_model_requires_login(self):
with self.assertNumQueries(0):
response = self.client.get(reverse("api:model-retrieve", kwargs={"pk": str(self.model2.id)}))
......@@ -934,8 +1010,63 @@ class TestModelAPI(FixtureAPITestCase):
"size": 8,
})
@patch("arkindex.project.aws.s3.meta.client.generate_presigned_url")
def test_partial_update_any_parent_model_version(self, s3_presigned_url):
"""
A model version can have any model version as a parent, not just a version from the same model
"""
s3_presigned_url.return_value = "http://s3/get_url"
self.client.force_login(self.user2)
with self.assertNumQueries(6):
response = self.client.patch(
reverse("api:model-version-retrieve", kwargs={"pk": str(self.model_version3.id)}),
# self.model_version3 is a version of self.model2, while self.model_version1 belongs to self.model1
{"parent": str(self.model_version1.id)},
format="json"
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertDictEqual(response.json(), {
"id": str(self.model_version3.id),
"model_id": str(self.model2.id),
"created": self.model_version3.created.isoformat().replace("+00:00", "Z"),
"s3_url": "http://s3/get_url",
"s3_put_url": None,
"tag": "tagged",
"description": "",
"configuration": {},
"parent": str(self.model_version1.id),
"state": ModelVersionState.Available.value,
"hash": "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbba",
"size": 8,
})
self.model_version3.refresh_from_db()
self.assertIsNone(self.model_version3.parent_id)
self.assertEqual(self.model_version3.parent_id, self.model_version1.id)
@patch("arkindex.users.managers.BaseACLManager.filter_rights")
def test_partial_update_readable_parent_model_version(self, filter_rights_mock):
"""
Only model versions the user has read access to can be set as parents of a model version
"""
filter_rights_mock.return_value = Model.objects.filter(id__in=[self.model2.id, self.model3.id])
self.client.force_login(self.user2)
with self.assertNumQueries(4):
response = self.client.patch(
reverse("api:model-version-retrieve", kwargs={"pk": str(self.model_version3.id)}),
# self.model_version3 is a version of self.model2, while self.model_version1 belongs to self.model1
{"parent": str(self.model_version1.id)},
format="json"
)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEqual(filter_rights_mock.call_count, 1)
self.assertEqual(filter_rights_mock.call_args, call(self.user2, Model, Role.Guest.value))
self.assertDictEqual(response.json(), {
"parent": [f'Invalid pk "{str(self.model_version1.id)}" - object does not exist.']
})
@patch("arkindex.training.api.get_max_level", return_value=None)
def test_update_model_version_requires_contributor(self, get_max_level_mock):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment