Skip to content
Snippets Groups Projects
Commit 42b9d08e authored by Yoann Schneider's avatar Yoann Schneider :tennis: Committed by Erwan Rouchet
Browse files

List model versions endpoint

parent 51f6b7e1
No related branches found
No related tags found
1 merge request!1651List model versions endpoint
......@@ -88,7 +88,7 @@ from arkindex.documents.api.ml import (
from arkindex.documents.api.search import CorpusSearch
from arkindex.images.api import IIIFInformationCreate, IIIFURLCreate, ImageCreate, ImageElements, ImageRetrieve
from arkindex.project.openapi import OpenApiSchemaView
from arkindex.training.api import ModelsList, ModelVersionsCreate, ModelVersionsUpdate
from arkindex.training.api import ModelsList, ModelVersionsList, ModelVersionsUpdate
from arkindex.users.api import (
CredentialsList,
CredentialsRetrieve,
......@@ -237,9 +237,9 @@ api = [
path('process/<uuid:pk>/apply/', ApplyProcessTemplate.as_view(), name='apply-process-template'),
# ML models training
path('model/<uuid:pk>/versions/', ModelVersionsCreate.as_view(), name='model-version-create'),
path('modelversion/<uuid:pk>/', ModelVersionsUpdate.as_view(), name='model-version-update'),
path('models/', ModelsList.as_view(), name='models'),
path('model/<uuid:pk>/versions/', ModelVersionsList.as_view(), name='model-versions'),
# Image management
path('image/', ImageCreate.as_view(), name='image-create'),
......
from django.db.models import Q
from django.shortcuts import get_object_or_404
from django.utils.functional import cached_property
from drf_spectacular.utils import OpenApiParameter, extend_schema, extend_schema_view
from rest_framework.exceptions import PermissionDenied, ValidationError
from rest_framework.generics import CreateAPIView, ListCreateAPIView, UpdateAPIView
from rest_framework.generics import ListCreateAPIView, UpdateAPIView
from arkindex.project.mixins import TrainingModelMixin
from arkindex.project.permissions import IsVerified
from arkindex.training.models import Model, ModelVersion
from arkindex.training.models import Model, ModelVersion, ModelVersionState
from arkindex.training.serializers import (
CreateModelErrorResponseSerializer,
ModelSerializer,
ModelVersionCreateSerializer,
ModelVersionSerializer,
)
from arkindex.users.models import Role
from arkindex.users.utils import get_max_level
@extend_schema(tags=['training'])
@extend_schema_view(
get=extend_schema(
operation_id='ListModelVersions',
description=(
'List available versions of a Machine Learning model.\n\n'
'Requires a **guest** access to the model.'
),
),
post=extend_schema(
operation_id='CreateModelVersion',
description=(
......@@ -27,24 +38,48 @@ from arkindex.training.serializers import (
},
),
)
class ModelVersionsCreate(TrainingModelMixin, CreateAPIView):
class ModelVersionsList(TrainingModelMixin, ListCreateAPIView):
permission_classes = (IsVerified, )
serializer_class = ModelVersionCreateSerializer
queryset = Model.objects.none()
def get_queryset(self):
return Model.objects.filter(id=self.kwargs['pk'])
return Model.objects.filter(pk=self.kwargs['pk'])
def perform_create(self, serializer):
if not self.model_access_rights or self.model_access_rights < Role.Contributor.value:
raise PermissionDenied(detail='You need a Contributor access to the model to create a new version.')
serializer.save()
@cached_property
def model(self):
return get_object_or_404(self.get_queryset(), id=self.kwargs['pk'])
@cached_property
def model_access_rights(self):
return get_max_level(self.request.user, self.model)
def get_serializer_context(self):
context = super().get_serializer_context()
context['model'] = self.get_object()
context['model'] = self.model
context['model_rights'] = self.model_access_rights
return context
def check_object_permissions(self, request, model):
super().check_object_permissions(request, model)
def filter_queryset(self, queryset):
filters = Q(model=self.kwargs['pk'])
if not self.has_write_access(model):
raise PermissionDenied(detail='You need a Contributor access to the model to create a new version.')
access_level = self.model_access_rights
if not access_level or access_level < Role.Guest.value:
raise PermissionDenied(detail='You need a guest access to the model to list its versions.')
# Guest access allow only versions with a set tag and state = available
if access_level < Role.Contributor.value:
filters &= Q(tag__isnull=False, state=ModelVersionState.Available)
# Count all versions for models with contributor+ access
queryset = ModelVersion.objects.filter(filters).order_by('-created')
return super().filter_queryset(queryset)
@extend_schema(tags=['training'])
......
......@@ -102,8 +102,8 @@ class ModelVersionCreateSerializer(ModelVersionSerializer):
tag = serializers.CharField(allow_null=True, max_length=50, required=False, default=None)
class Meta(ModelVersionSerializer.Meta):
fields = ModelVersionSerializer.Meta.fields + ('s3_put_url',)
read_only_fields = ModelVersionSerializer.Meta.read_only_fields + ('s3_put_url',)
fields = ModelVersionSerializer.Meta.fields + ('s3_put_url', 'created', )
read_only_fields = ModelVersionSerializer.Meta.read_only_fields + ('s3_put_url', 'created', )
@extend_schema_field(serializers.CharField(allow_null=True))
def get_s3_put_url(self, obj):
......@@ -111,6 +111,11 @@ class ModelVersionCreateSerializer(ModelVersionSerializer):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# When used for CreateModelVersion, the API adds the provided model to the context
# The API adds the provided model to the context
if self.context.get('model') is not None:
self.fields['parent'].queryset = ModelVersion.objects.filter(model_id=self.context['model'].id)
# If user doesn't have a contributor access to the model, don't show s3_url and s3_put_url
access_level = self.context.get('model_rights')
if not access_level or access_level < Role.Contributor.value:
del self.fields['s3_url']
del self.fields['s3_put_url']
......@@ -3,6 +3,7 @@
from unittest.mock import patch
from django.urls import reverse
from django.utils import timezone
from rest_framework import status
from arkindex.project.tests import FixtureAPITestCase
......@@ -10,6 +11,10 @@ from arkindex.training.models import Model, ModelVersion, ModelVersionState
from arkindex.users.models import Group, Right, Role, User
def _format_datetime(date):
return str(date.isoformat().replace('+00:00', 'Z')).replace(' ', 'T')
class TestModelAPI(FixtureAPITestCase):
"""
Test model and model version api
......@@ -78,7 +83,7 @@ class TestModelAPI(FixtureAPITestCase):
self.client.force_login(user)
request = self.build_model_version_create_request()
with self.assertNumQueries(2):
response = self.client.post(reverse("api:model-version-create", kwargs={"pk": str(self.model2.id)}), request, format='json')
response = self.client.post(reverse("api:model-versions", kwargs={"pk": str(self.model2.id)}), request, format='json')
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertDictEqual(response.json(), {"detail": "You do not have permission to perform this action."})
......@@ -88,8 +93,8 @@ class TestModelAPI(FixtureAPITestCase):
"""
self.client.force_login(self.user1)
request = self.build_model_version_create_request()
with self.assertNumQueries(6):
response = self.client.post(reverse("api:model-version-create", kwargs={"pk": str(self.model2.id)}), request, format='json')
with self.assertNumQueries(7):
response = self.client.post(reverse("api:model-versions", kwargs={"pk": str(self.model2.id)}), request, format='json')
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertDictEqual(response.json(), {"detail": "You need a Contributor access to the model to create a new version."})
......@@ -101,13 +106,17 @@ class TestModelAPI(FixtureAPITestCase):
s3_presigned_url_mock.return_value = 'http://s3/upload_put_url'
self.client.force_login(self.user1)
request = self.build_model_version_create_request()
with self.assertNumQueries(8):
response = self.client.post(reverse('api:model-version-create', kwargs={"pk": str(self.model1.id)}), request, format='json')
fake_now = timezone.now()
# To mock the creation date
with patch('django.utils.timezone.now') as mock_now:
mock_now.return_value = fake_now
with self.assertNumQueries(8):
response = self.client.post(reverse('api:model-versions', kwargs={"pk": str(self.model1.id)}), request, format='json')
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
data = response.json()
self.assertIn('id', data)
df = ModelVersion.objects.get(id=data['id'])
self.maxDiff = None
self.assertDictEqual(
data,
......@@ -121,6 +130,7 @@ class TestModelAPI(FixtureAPITestCase):
'tag': None,
'size': request['size'],
'hash': request['hash'],
'created': _format_datetime(fake_now),
's3_url': s3_presigned_url_mock.return_value,
's3_put_url': s3_presigned_url_mock.return_value
}
......@@ -134,7 +144,7 @@ class TestModelAPI(FixtureAPITestCase):
request = self.build_model_version_create_request()
request['tag'] = ''
with self.assertNumQueries(7):
response = self.client.post(reverse('api:model-version-create', kwargs={"pk": str(self.model1.id)}), request, format='json')
response = self.client.post(reverse('api:model-versions', kwargs={"pk": str(self.model1.id)}), request, format='json')
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertDictEqual(response.json(), {"tag": ["This field may not be blank."]})
......@@ -151,8 +161,12 @@ class TestModelAPI(FixtureAPITestCase):
'hash': '05a5cdbaf05d3b6cc51fcb173d0057c0',
'size': 8,
}
with self.assertNumQueries(9):
response = self.client.post(reverse('api:model-version-create', kwargs={"pk": str(self.model1.id)}), request, format='json')
fake_now = timezone.now()
# To mock the creation date
with patch('django.utils.timezone.now') as mock_now:
mock_now.return_value = fake_now
with self.assertNumQueries(9):
response = self.client.post(reverse('api:model-versions', kwargs={"pk": str(self.model1.id)}), request, format='json')
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
data = response.json()
self.assertIn('id', data)
......@@ -170,6 +184,7 @@ class TestModelAPI(FixtureAPITestCase):
'tag': request['tag'],
'size': request['size'],
'hash': request['hash'],
'created': _format_datetime(fake_now),
's3_url': s3_presigned_url_mock.return_value,
's3_put_url': s3_presigned_url_mock.return_value
}
......@@ -183,7 +198,7 @@ class TestModelAPI(FixtureAPITestCase):
request = self.build_model_version_create_request()
request['tag'] = self.model_version2.tag
with self.assertNumQueries(8):
response = self.client.post(reverse('api:model-version-create', kwargs={"pk": str(self.model1.id)}), request, format='json')
response = self.client.post(reverse('api:model-versions', kwargs={"pk": str(self.model1.id)}), request, format='json')
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertDictEqual(response.json(), {"non_field_errors": ["A version for this model with this tag already exists."]})
......@@ -195,7 +210,7 @@ class TestModelAPI(FixtureAPITestCase):
self.client.force_login(self.user1)
request = {'tag': 'production', 'hash': self.model_version5.hash, 'size': 32}
with self.assertNumQueries(7):
response = self.client.post(reverse('api:model-version-create', kwargs={"pk": str(self.model1.id)}), request, format='json')
response = self.client.post(reverse('api:model-versions', kwargs={"pk": str(self.model1.id)}), request, format='json')
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertDictEqual(response.json(), {"hash": ["A version for this model with this hash already exists."]})
......@@ -207,7 +222,7 @@ class TestModelAPI(FixtureAPITestCase):
self.client.force_login(self.user1)
request = {'tag': 'production', 'hash': self.model_version2.hash, 'size': 32}
with self.assertNumQueries(7):
response = self.client.post(reverse('api:model-version-create', kwargs={"pk": str(self.model1.id)}), request, format='json')
response = self.client.post(reverse('api:model-versions', kwargs={"pk": str(self.model1.id)}), request, format='json')
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertDictEqual(
response.json()['hash'],
......@@ -219,6 +234,7 @@ class TestModelAPI(FixtureAPITestCase):
'state': self.model_version2.state.value,
'configuration': self.model_version2.configuration,
'tag': self.model_version2.tag,
'created': _format_datetime(self.model_version2.created),
'size': str(self.model_version2.size),
'hash': self.model_version2.hash,
's3_url': self.model_version2.s3_url,
......@@ -574,6 +590,7 @@ class TestModelAPI(FixtureAPITestCase):
response = self.client.get(reverse('api:models'))
self.assertEqual(response.status_code, status.HTTP_200_OK)
models = response.json()['results']
self.assertEqual(len(models), 2)
self.assertListEqual([model['id'] for model in models], [str(self.model1.id), str(self.model2.id)])
......@@ -588,3 +605,38 @@ class TestModelAPI(FixtureAPITestCase):
self.assertEqual(len(models), 1)
self.assertListEqual([model['id'] for model in models], [str(self.model2.id)])
def test_list_model_versions_requires_logged_in(self):
"""To list a model's versions, you need to be logged in.
"""
response = self.client.get(reverse('api:model-versions', kwargs={"pk": str(self.model1.id)}))
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertDictEqual(response.json(), {"detail": "Authentication credentials were not provided."})
def test_list_model_versions_requires_guest_access(self):
"""To view a model's version, you need guest access on the model
"""
self.client.force_login(self.user1)
with self.assertNumQueries(6):
response = self.client.get(reverse('api:model-versions', kwargs={"pk": str(self.model2.id)}))
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertDictEqual(response.json(), {"detail": 'You need a guest access to the model to list its versions.'})
def test_list_model_versions_low_access(self):
"""With only guest access rights on a model, you only see the available versions with a set tag
"""
self.client.force_login(self.user3)
with self.assertNumQueries(8):
response = self.client.get(reverse('api:model-versions', kwargs={"pk": str(self.model2.id)}))
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertListEqual([version["id"] for version in response.json()['results']], [str(self.model_version3.id)])
def test_list_model_versions_full_access(self):
"""With contributor access rights, you see every versions.
"""
self.client.force_login(self.user2)
with self.assertNumQueries(8):
response = self.client.get(reverse('api:model-versions', kwargs={"pk": str(self.model2.id)}))
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertListEqual([version["id"] for version in response.json()['results']], [str(self.model_version4.id), str(self.model_version3.id)])
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment