Skip to content
Snippets Groups Projects
Commit e3a9072f authored by Erwan Rouchet's avatar Erwan Rouchet
Browse files

Merge branch 'delete-model-version' into 'master'

Destroy model version endpoint

Closes #994

See merge request !1662
parents e1355ce7 dcad384f
No related branches found
Tags 1.2.3-rc1
1 merge request!1662Destroy model version endpoint
......@@ -89,7 +89,7 @@ from arkindex.documents.api.ml import (
from arkindex.documents.api.search import CorpusSearch
from arkindex.images.api import IIIFInformationCreate, IIIFURLCreate, ImageCreate, ImageElements, ImageRetrieve
from arkindex.project.openapi import OpenApiSchemaView
from arkindex.training.api import ModelsList, ModelVersionsList, ModelVersionsUpdate
from arkindex.training.api import ModelsList, ModelVersionsList, ModelVersionsRetrieve
from arkindex.users.api import (
CredentialsList,
CredentialsRetrieve,
......@@ -239,7 +239,7 @@ api = [
path('process/<uuid:pk>/apply/', ApplyProcessTemplate.as_view(), name='apply-process-template'),
# ML models training
path('modelversion/<uuid:pk>/', ModelVersionsUpdate.as_view(), name='model-version-update'),
path('modelversion/<uuid:pk>/', ModelVersionsRetrieve.as_view(), name='model-version-retrieve'),
path('models/', ModelsList.as_view(), name='models'),
path('model/<uuid:pk>/versions/', ModelVersionsList.as_view(), name='model-versions'),
......
......@@ -3,7 +3,7 @@ from django.shortcuts import get_object_or_404
from django.utils.functional import cached_property
from drf_spectacular.utils import OpenApiParameter, extend_schema, extend_schema_view
from rest_framework.exceptions import PermissionDenied, ValidationError
from rest_framework.generics import ListCreateAPIView, UpdateAPIView
from rest_framework.generics import ListCreateAPIView, RetrieveUpdateDestroyAPIView
from arkindex.project.mixins import TrainingModelMixin
from arkindex.project.permissions import IsVerified
......@@ -76,7 +76,7 @@ class ModelVersionsList(TrainingModelMixin, ListCreateAPIView):
if access_level < Role.Contributor.value:
filters &= Q(tag__isnull=False, state=ModelVersionState.Available)
# Count all versions for models with contributor+ access
# Order version by creation date
queryset = ModelVersion.objects.filter(filters).order_by('-created')
return super().filter_queryset(queryset)
......@@ -93,8 +93,13 @@ class ModelVersionsList(TrainingModelMixin, ListCreateAPIView):
patch=extend_schema(
description='Partially update a version of a Machine Learning model.\n\nRequires a **contributor** access to the model.'
),
delete=extend_schema(
description='Delete a model version.\n\nRequires an **admin** access on the related model.'
)
)
class ModelVersionsUpdate(TrainingModelMixin, UpdateAPIView):
class ModelVersionsRetrieve(TrainingModelMixin, RetrieveUpdateDestroyAPIView):
"""Retrieve a version of Machine Learning model
"""
permission_classes = (IsVerified, )
serializer_class = ModelVersionSerializer
queryset = ModelVersion.objects.none()
......@@ -106,13 +111,41 @@ class ModelVersionsUpdate(TrainingModelMixin, UpdateAPIView):
def get_serializer_context(self):
context = super().get_serializer_context()
context['model'] = self.get_object().model
context['model'] = self.model
return context
@cached_property
def model(self):
return get_object_or_404(self.get_queryset(), id=self.kwargs['pk']).model
@cached_property
def model_access_rights(self):
return get_max_level(self.request.user, self.model)
def check_object_permissions(self, request, model_version):
super().check_object_permissions(request, model_version.model)
if not self.has_write_access(model_version.model):
access_level = self.model_access_rights
if model_version.state == ModelVersionState.Available and model_version.tag is not None:
needed_level, error_msg = Role.Guest.value, 'Guest'
else:
needed_level, error_msg = Role.Contributor.value, 'Contributor'
if not access_level or access_level < needed_level:
raise PermissionDenied(detail=f'You need a {error_msg} access to the model to retrieve this version.')
return super().check_object_permissions(request, model_version)
def perform_update(self, serializer):
access_level = self.model_access_rights
if not access_level or access_level < Role.Contributor.value:
raise PermissionDenied(detail='You need a Contributor access to the model to edit this version.')
return super().perform_update(serializer)
def perform_destroy(self, instance):
access_level = self.model_access_rights
if not access_level or access_level < Role.Admin.value:
raise PermissionDenied(detail='You need an Admin access to the model to destroy this version.')
super().perform_destroy(instance)
@extend_schema(tags=['training'])
......
......@@ -26,6 +26,24 @@ def _deserialize_model(model, access_rights):
}
def _deserialize_model_version(model_version):
# Only stringify if we get an ID
parent = str(model_version.parent) if model_version.parent else model_version.parent
return {
'id': str(model_version.id),
'model_id': str(model_version.model_id),
'parent': parent,
'description': model_version.description,
'tag': model_version.tag,
'hash': model_version.hash,
'archive_hash': model_version.archive_hash,
'state': model_version.state.value,
'size': model_version.size,
'configuration': model_version.configuration,
's3_url': model_version.s3_url,
}
class TestModelAPI(FixtureAPITestCase):
"""
Test model and model version api
......@@ -268,13 +286,18 @@ class TestModelAPI(FixtureAPITestCase):
}
)
def test_partial_update_model_version_requires_contributor(self):
@patch('arkindex.project.aws.s3.Object')
@patch('arkindex.project.aws.S3FileMixin.exists')
def test_partial_update_model_version_requires_contributor(self, exists, s3_object):
"""
Can't partial update a model version as guest
"""
s3_object().content_length = self.model_version3.size
s3_object().e_tag = self.model_version3.archive_hash
exists.return_value = True
self.client.force_login(self.user3)
with self.assertNumQueries(6):
response = self.client.patch(reverse("api:model-version-update", kwargs={"pk": str(self.model_version1.id)}), {'state': 'available'})
with self.assertNumQueries(9):
response = self.client.patch(reverse("api:model-version-retrieve", kwargs={"pk": str(self.model_version3.id)}), {'state': 'available'})
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertDictEqual(response.json(), {"detail": "You need a Contributor access to the model to edit this version."})
......@@ -295,8 +318,8 @@ class TestModelAPI(FixtureAPITestCase):
"state" : "available",
"configuration": {"parameter1": "value1"}
}
with self.assertNumQueries(11):
response = self.client.patch(reverse('api:model-version-update', kwargs={"pk": str(self.model_version1.id)}), request, format='json')
with self.assertNumQueries(10):
response = self.client.patch(reverse('api:model-version-retrieve', kwargs={"pk": str(self.model_version1.id)}), request, format='json')
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertDictEqual(response.json(), {
......@@ -319,8 +342,8 @@ class TestModelAPI(FixtureAPITestCase):
"""
self.client.force_login(self.user1)
exists.return_value = False
with self.assertNumQueries(8):
response = self.client.patch(reverse('api:model-version-update', kwargs={"pk": str(self.model_version1.id)}), {
with self.assertNumQueries(7):
response = self.client.patch(reverse('api:model-version-retrieve', kwargs={"pk": str(self.model_version1.id)}), {
'state': 'available'
})
......@@ -335,8 +358,8 @@ class TestModelAPI(FixtureAPITestCase):
exists.return_value = True
s3_object().content_length = self.model_version1.size + 1
self.client.force_login(self.user1)
with self.assertNumQueries(8):
response = self.client.patch(reverse('api:model-version-update', kwargs={"pk": str(self.model_version1.id)}), {
with self.assertNumQueries(7):
response = self.client.patch(reverse('api:model-version-retrieve', kwargs={"pk": str(self.model_version1.id)}), {
'state': 'available',
})
......@@ -352,8 +375,8 @@ class TestModelAPI(FixtureAPITestCase):
s3_object().content_length = self.model_version1.size
s3_object().e_tag = f'"{self.model_version2.hash}"'
self.client.force_login(self.user1)
with self.assertNumQueries(9):
response = self.client.patch(reverse('api:model-version-update', kwargs={"pk": str(self.model_version1.id)}), {
with self.assertNumQueries(8):
response = self.client.patch(reverse('api:model-version-retrieve', kwargs={"pk": str(self.model_version1.id)}), {
'state': 'available',
})
......@@ -376,21 +399,26 @@ class TestModelAPI(FixtureAPITestCase):
"tag": self.model_version2.tag,
"state": "available",
}
with self.assertNumQueries(10):
response = self.client.patch(reverse('api:model-version-update', kwargs={"pk": str(self.model_version1.id)}), request, format='json')
with self.assertNumQueries(9):
response = self.client.patch(reverse('api:model-version-retrieve', kwargs={"pk": str(self.model_version1.id)}), request, format='json')
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertDictEqual(response.json(), {"non_field_errors": ["A version for this model with this tag already exists."]})
def test_update_model_version_requires_contributor(self):
@patch('arkindex.project.aws.s3.Object')
@patch('arkindex.project.aws.S3FileMixin.exists')
def test_update_model_version_requires_contributor(self, exists, s3_object):
"""
Can't update a model version with guest access rights to the model
"""
s3_object().content_length = self.model_version3.size
s3_object().e_tag = self.model_version3.archive_hash
exists.return_value = True
self.client.force_login(self.user3)
request = self.build_model_version_update_request()
with self.assertNumQueries(6):
response = self.client.put(reverse("api:model-version-update", kwargs={"pk": str(self.model_version1.id)}), request, format='json')
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
with self.assertNumQueries(9):
response = self.client.put(reverse("api:model-version-retrieve", kwargs={"pk": str(self.model_version3.id)}), request, format='json')
# self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertDictEqual(response.json(), {"detail": "You need a Contributor access to the model to edit this version."})
def test_update_model_version_requires_all_parameters(self):
......@@ -398,8 +426,8 @@ class TestModelAPI(FixtureAPITestCase):
Update endpoint requires every parameter
"""
self.client.force_login(self.user1)
with self.assertNumQueries(8):
response = self.client.put(reverse("api:model-version-update", kwargs={"pk": str(self.model_version1.id)}), {})
with self.assertNumQueries(7):
response = self.client.put(reverse("api:model-version-retrieve", kwargs={"pk": str(self.model_version1.id)}), {})
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertDictEqual(response.json(), {'description': ['This field is required.'], 'state': ['This field is required.'], 'configuration': ['This field is required.']})
......@@ -418,7 +446,7 @@ class TestModelAPI(FixtureAPITestCase):
self.client.force_login(self.user1)
params = self.build_model_version_update_request()
for param, value in params.items():
response = self.client.put(reverse("api:model-version-update", kwargs={"pk": str(self.model_version1.id)}), {param: value}, format='json')
response = self.client.put(reverse("api:model-version-retrieve", kwargs={"pk": str(self.model_version1.id)}), {param: value}, format='json')
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertDictEqual(response.json(), {other_param: ['This field is required.'] for other_param in params.keys() if other_param != param}, f'{param}')
......@@ -439,8 +467,8 @@ class TestModelAPI(FixtureAPITestCase):
"state": "available",
"configuration" : {"parameter1": "value1"}
}
with self.assertNumQueries(11):
response = self.client.put(reverse('api:model-version-update', kwargs={"pk": str(self.model_version2.id)}), request, format='json')
with self.assertNumQueries(10):
response = self.client.put(reverse('api:model-version-retrieve', kwargs={"pk": str(self.model_version2.id)}), request, format='json')
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertDictEqual(response.json(), {
"id": str(self.model_version2.id),
......@@ -463,8 +491,8 @@ class TestModelAPI(FixtureAPITestCase):
self.client.force_login(self.user1)
exists.return_value = False
request = self.build_model_version_update_request()
with self.assertNumQueries(8):
response = self.client.put(reverse('api:model-version-update', kwargs={"pk": str(self.model_version1.id)}), request, format='json')
with self.assertNumQueries(7):
response = self.client.put(reverse('api:model-version-retrieve', kwargs={"pk": str(self.model_version1.id)}), request, format='json')
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertDictEqual(response.json(), {"state": ['Archive has not been uploaded']})
......@@ -478,8 +506,8 @@ class TestModelAPI(FixtureAPITestCase):
s3_object().content_length = self.model_version1.size + 1
self.client.force_login(self.user1)
request = self.build_model_version_update_request()
with self.assertNumQueries(8):
response = self.client.put(reverse('api:model-version-update', kwargs={"pk": str(self.model_version1.id)}), request, format='json')
with self.assertNumQueries(7):
response = self.client.put(reverse('api:model-version-retrieve', kwargs={"pk": str(self.model_version1.id)}), request, format='json')
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertDictEqual(response.json(), {"state": [f'Uploaded file size is {self.model_version1.size + 1} bytes, expected {self.model_version1.size} bytes']})
......@@ -494,8 +522,8 @@ class TestModelAPI(FixtureAPITestCase):
s3_object().e_tag = f'"{self.model_version2.archive_hash}"'
self.client.force_login(self.user1)
request = self.build_model_version_update_request()
with self.assertNumQueries(9):
response = self.client.put(reverse('api:model-version-update', kwargs={"pk": str(self.model_version1.id)}), request, format='json')
with self.assertNumQueries(8):
response = self.client.put(reverse('api:model-version-retrieve', kwargs={"pk": str(self.model_version1.id)}), request, format='json')
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertDictEqual(response.json(), {"state": ['MD5 hashes do not match']})
......@@ -673,3 +701,57 @@ class TestModelAPI(FixtureAPITestCase):
response = self.client.get(reverse('api:model-versions', kwargs={"pk": str(self.model2.id)}))
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertListEqual([version["id"] for version in response.json()['results']], [str(self.model_version4.id), str(self.model_version3.id)])
def test_destroy_model_versions_requires_admin(self):
"""To destroy a model version, you need admin rights on the model.
"""
self.client.force_login(self.user2)
with self.assertNumQueries(7):
response = self.client.delete(reverse('api:model-version-retrieve', kwargs={"pk": str(self.model_version1.id)}))
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertDictEqual(response.json(), {"detail": "You need an Admin access to the model to destroy this version."})
def test_destroy_model_versions(self):
"""To destroy a model version, you need admin rights on the model.
"""
self.client.force_login(self.user1)
with self.assertNumQueries(9):
response = self.client.delete(reverse('api:model-version-retrieve', kwargs={"pk": str(self.model_version1.id)}))
self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
def test_retrieve_model_versions_require_guest(self):
"""To retrieve a model version with a set tag and state==Available, you need guest rights on the model.
"""
self.client.force_login(self.user1)
with self.assertNumQueries(7):
response = self.client.get(reverse('api:model-version-retrieve', kwargs={"pk": str(self.model_version3.id)}))
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertDictEqual(response.json(), {"detail": "You need a Guest access to the model to retrieve this version."})
def test_retrieve_model_versions_tag_available(self):
"""Retrieve a model version with a set tag and state==Available with guest rights on the model.
"""
self.client.force_login(self.user3)
with self.assertNumQueries(7):
response = self.client.get(reverse('api:model-version-retrieve', kwargs={"pk": str(self.model_version3.id)}))
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertDictEqual(response.json(), _deserialize_model_version(self.model_version3))
def test_retrieve_model_versions_require_contributor(self):
"""To retrieve a model version with no set tag or state!=Available, you need contributor rights on the model.
"""
self.client.force_login(self.user3)
with self.assertNumQueries(7):
response = self.client.get(reverse('api:model-version-retrieve', kwargs={"pk": str(self.model_version4.id)}))
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertDictEqual(response.json(), {"detail": "You need a Contributor access to the model to retrieve this version."})
def test_retrieve_model_versions(self):
"""Retrieve a model version with no set tag or state!=Available with contributor rights on the model.
"""
self.client.force_login(self.user2)
with self.assertNumQueries(7):
response = self.client.get(reverse('api:model-version-retrieve', kwargs={"pk": str(self.model_version4.id)}))
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.maxDiff = None
self.assertDictEqual(response.json(), _deserialize_model_version(self.model_version4))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment