Skip to content
Snippets Groups Projects

Allow setting any readable model version as parent of a model version, not just versions from the same model

Merged ml bonhomme requested to merge any-model-parent into master
All threads resolved!
Files
3
@@ -266,6 +266,82 @@ class TestModelAPI(FixtureAPITestCase):
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertDictEqual(response.json(), {"non_field_errors": ["A version for this model with this tag already exists."]})
@patch("arkindex.project.aws.s3.meta.client.generate_presigned_url")
def test_create_model_version_any_parent_model_version(self, s3_presigned_url_mock):
"""
Any readable model version can be set as parent of a model version, not just versions of the same model
"""
self.client.force_login(self.user1)
s3_presigned_url_mock.return_value = "http://s3/upload_put_url"
fake_now = timezone.now()
# To mock the creation date
with patch("django.utils.timezone.now") as mock_now:
mock_now.return_value = fake_now
with self.assertNumQueries(6):
response = self.client.post(
reverse("api:model-versions", kwargs={"pk": str(self.model1.id)}),
{
"tag": "TAG",
"description": "description",
"configuration": {"hello": "this is me"},
# self.model_version3 belongs to self.model2
"parent": str(self.model_version3.id)
},
format="json",
)
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
data = response.json()
self.assertIn("id", data)
df = ModelVersion.objects.get(id=data["id"])
self.assertDictEqual(
data,
{
"id": str(df.id),
"model_id": str(self.model1.id),
"parent": str(self.model_version3.id),
"description": "description",
"state": ModelVersionState.Created.value,
"configuration": {"hello": "this is me"},
"tag": "TAG",
"size": None,
"hash": None,
"created": fake_now.isoformat().replace("+00:00", "Z"),
"s3_url": None,
"s3_put_url": s3_presigned_url_mock.return_value
}
)
@patch("arkindex.users.managers.BaseACLManager.filter_rights")
def test_create_model_version_readable_parent_version(self, filter_rights_mock):
"""
Only model versions the user has read access to can be set as parents of a model version
"""
filter_rights_mock.return_value = Model.objects.filter(id=self.model1.id)
self.client.force_login(self.user1)
fake_now = timezone.now()
# To mock the creation date
with patch("django.utils.timezone.now") as mock_now:
mock_now.return_value = fake_now
with self.assertNumQueries(4):
response = self.client.post(
reverse("api:model-versions", kwargs={"pk": str(self.model1.id)}),
{
"tag": "TAG",
"description": "description",
"configuration": {"hello": "this is me"},
# self.model_version3 belongs to self.model2
"parent": str(self.model_version3.id)
},
format="json",
)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEqual(filter_rights_mock.call_count, 1)
self.assertEqual(filter_rights_mock.call_args, call(self.user1, Model, Role.Guest.value))
self.assertDictEqual(response.json(), {
"parent": [f'Invalid pk "{str(self.model_version3.id)}" - object does not exist.']
})
def test_retrieve_model_requires_login(self):
with self.assertNumQueries(0):
response = self.client.get(reverse("api:model-retrieve", kwargs={"pk": str(self.model2.id)}))
@@ -934,8 +1010,63 @@ class TestModelAPI(FixtureAPITestCase):
"size": 8,
})
@patch("arkindex.project.aws.s3.meta.client.generate_presigned_url")
def test_partial_update_any_parent_model_version(self, s3_presigned_url):
"""
A model version can have any model version as a parent, not just a version from the same model
"""
s3_presigned_url.return_value = "http://s3/get_url"
self.client.force_login(self.user2)
with self.assertNumQueries(6):
response = self.client.patch(
reverse("api:model-version-retrieve", kwargs={"pk": str(self.model_version3.id)}),
# self.model_version3 is a version of self.model2, while self.model_version1 belongs to self.model1
{"parent": str(self.model_version1.id)},
format="json"
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertDictEqual(response.json(), {
"id": str(self.model_version3.id),
"model_id": str(self.model2.id),
"created": self.model_version3.created.isoformat().replace("+00:00", "Z"),
"s3_url": "http://s3/get_url",
"s3_put_url": None,
"tag": "tagged",
"description": "",
"configuration": {},
"parent": str(self.model_version1.id),
"state": ModelVersionState.Available.value,
"hash": "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbba",
"size": 8,
})
self.model_version3.refresh_from_db()
self.assertIsNone(self.model_version3.parent_id)
self.assertEqual(self.model_version3.parent_id, self.model_version1.id)
@patch("arkindex.users.managers.BaseACLManager.filter_rights")
def test_partial_update_readable_parent_model_version(self, filter_rights_mock):
"""
Only model versions the user has read access to can be set as parents of a model version
"""
filter_rights_mock.return_value = Model.objects.filter(id__in=[self.model2.id, self.model3.id])
self.client.force_login(self.user2)
with self.assertNumQueries(4):
response = self.client.patch(
reverse("api:model-version-retrieve", kwargs={"pk": str(self.model_version3.id)}),
# self.model_version3 is a version of self.model2, while self.model_version1 belongs to self.model1
{"parent": str(self.model_version1.id)},
format="json"
)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEqual(filter_rights_mock.call_count, 1)
self.assertEqual(filter_rights_mock.call_args, call(self.user2, Model, Role.Guest.value))
self.assertDictEqual(response.json(), {
"parent": [f'Invalid pk "{str(self.model_version1.id)}" - object does not exist.']
})
@patch("arkindex.training.api.get_max_level", return_value=None)
def test_update_model_version_requires_contributor(self, get_max_level_mock):
Loading