diff --git a/arkindex/project/api_v1.py b/arkindex/project/api_v1.py index f9c7c58e08164ff32b26b104f587350f9b886238..7a9b35e74f08661dfba27b3a3dd65a1e22788390 100644 --- a/arkindex/project/api_v1.py +++ b/arkindex/project/api_v1.py @@ -88,7 +88,7 @@ from arkindex.documents.api.ml import ( from arkindex.documents.api.search import CorpusSearch from arkindex.images.api import IIIFInformationCreate, IIIFURLCreate, ImageCreate, ImageElements, ImageRetrieve from arkindex.project.openapi import OpenApiSchemaView -from arkindex.training.api import ModelsList, ModelVersionsCreate, ModelVersionsUpdate +from arkindex.training.api import ModelsList, ModelVersionsList, ModelVersionsUpdate from arkindex.users.api import ( CredentialsList, CredentialsRetrieve, @@ -237,9 +237,9 @@ api = [ path('process/<uuid:pk>/apply/', ApplyProcessTemplate.as_view(), name='apply-process-template'), # ML models training - path('model/<uuid:pk>/versions/', ModelVersionsCreate.as_view(), name='model-version-create'), path('modelversion/<uuid:pk>/', ModelVersionsUpdate.as_view(), name='model-version-update'), path('models/', ModelsList.as_view(), name='models'), + path('model/<uuid:pk>/versions/', ModelVersionsList.as_view(), name='model-versions'), # Image management path('image/', ImageCreate.as_view(), name='image-create'), diff --git a/arkindex/training/api.py b/arkindex/training/api.py index 788fbb7b6389aa1e3611b93159d773340b6d7824..a01b5a56e28367c8bce9570b48f9fe754c984820 100644 --- a/arkindex/training/api.py +++ b/arkindex/training/api.py @@ -1,21 +1,32 @@ from django.db.models import Q +from django.shortcuts import get_object_or_404 +from django.utils.functional import cached_property from drf_spectacular.utils import OpenApiParameter, extend_schema, extend_schema_view from rest_framework.exceptions import PermissionDenied, ValidationError -from rest_framework.generics import CreateAPIView, ListCreateAPIView, UpdateAPIView +from rest_framework.generics import ListCreateAPIView, UpdateAPIView from arkindex.project.mixins import TrainingModelMixin from arkindex.project.permissions import IsVerified -from arkindex.training.models import Model, ModelVersion +from arkindex.training.models import Model, ModelVersion, ModelVersionState from arkindex.training.serializers import ( CreateModelErrorResponseSerializer, ModelSerializer, ModelVersionCreateSerializer, ModelVersionSerializer, ) +from arkindex.users.models import Role +from arkindex.users.utils import get_max_level @extend_schema(tags=['training']) @extend_schema_view( + get=extend_schema( + operation_id='ListModelVersions', + description=( + 'List available versions of a Machine Learning model.\n\n' + 'Requires a **guest** access to the model.' + ), + ), post=extend_schema( operation_id='CreateModelVersion', description=( @@ -27,24 +38,48 @@ from arkindex.training.serializers import ( }, ), ) -class ModelVersionsCreate(TrainingModelMixin, CreateAPIView): +class ModelVersionsList(TrainingModelMixin, ListCreateAPIView): permission_classes = (IsVerified, ) serializer_class = ModelVersionCreateSerializer queryset = Model.objects.none() def get_queryset(self): - return Model.objects.filter(id=self.kwargs['pk']) + return Model.objects.filter(pk=self.kwargs['pk']) + + def perform_create(self, serializer): + if not self.model_access_rights or self.model_access_rights < Role.Contributor.value: + raise PermissionDenied(detail='You need a Contributor access to the model to create a new version.') + serializer.save() + + @cached_property + def model(self): + return get_object_or_404(self.get_queryset(), id=self.kwargs['pk']) + + @cached_property + def model_access_rights(self): + return get_max_level(self.request.user, self.model) def get_serializer_context(self): context = super().get_serializer_context() - context['model'] = self.get_object() + context['model'] = self.model + context['model_rights'] = self.model_access_rights return context - def check_object_permissions(self, request, model): - super().check_object_permissions(request, model) + def filter_queryset(self, queryset): + filters = Q(model=self.kwargs['pk']) - if not self.has_write_access(model): - raise PermissionDenied(detail='You need a Contributor access to the model to create a new version.') + access_level = self.model_access_rights + if not access_level or access_level < Role.Guest.value: + raise PermissionDenied(detail='You need a guest access to the model to list its versions.') + + # Guest access allow only versions with a set tag and state = available + if access_level < Role.Contributor.value: + filters &= Q(tag__isnull=False, state=ModelVersionState.Available) + + # Count all versions for models with contributor+ access + queryset = ModelVersion.objects.filter(filters).order_by('-created') + + return super().filter_queryset(queryset) @extend_schema(tags=['training']) diff --git a/arkindex/training/serializers.py b/arkindex/training/serializers.py index 5b877dae225cebed61e07dce7a33a50005e18124..0e565bb1837974ee61f274643e7677024a6d5309 100644 --- a/arkindex/training/serializers.py +++ b/arkindex/training/serializers.py @@ -102,8 +102,8 @@ class ModelVersionCreateSerializer(ModelVersionSerializer): tag = serializers.CharField(allow_null=True, max_length=50, required=False, default=None) class Meta(ModelVersionSerializer.Meta): - fields = ModelVersionSerializer.Meta.fields + ('s3_put_url',) - read_only_fields = ModelVersionSerializer.Meta.read_only_fields + ('s3_put_url',) + fields = ModelVersionSerializer.Meta.fields + ('s3_put_url', 'created', ) + read_only_fields = ModelVersionSerializer.Meta.read_only_fields + ('s3_put_url', 'created', ) @extend_schema_field(serializers.CharField(allow_null=True)) def get_s3_put_url(self, obj): @@ -111,6 +111,11 @@ class ModelVersionCreateSerializer(ModelVersionSerializer): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - # When used for CreateModelVersion, the API adds the provided model to the context + # The API adds the provided model to the context if self.context.get('model') is not None: self.fields['parent'].queryset = ModelVersion.objects.filter(model_id=self.context['model'].id) + # If user doesn't have a contributor access to the model, don't show s3_url and s3_put_url + access_level = self.context.get('model_rights') + if not access_level or access_level < Role.Contributor.value: + del self.fields['s3_url'] + del self.fields['s3_put_url'] diff --git a/arkindex/training/tests/test_model_api.py b/arkindex/training/tests/test_model_api.py index 94531da6487531492911876a0ebbe87036adef56..0d80cc689ad9a5ca24c9716781921702e1773c04 100644 --- a/arkindex/training/tests/test_model_api.py +++ b/arkindex/training/tests/test_model_api.py @@ -3,6 +3,7 @@ from unittest.mock import patch from django.urls import reverse +from django.utils import timezone from rest_framework import status from arkindex.project.tests import FixtureAPITestCase @@ -10,6 +11,10 @@ from arkindex.training.models import Model, ModelVersion, ModelVersionState from arkindex.users.models import Group, Right, Role, User +def _format_datetime(date): + return str(date.isoformat().replace('+00:00', 'Z')).replace(' ', 'T') + + class TestModelAPI(FixtureAPITestCase): """ Test model and model version api @@ -78,7 +83,7 @@ class TestModelAPI(FixtureAPITestCase): self.client.force_login(user) request = self.build_model_version_create_request() with self.assertNumQueries(2): - response = self.client.post(reverse("api:model-version-create", kwargs={"pk": str(self.model2.id)}), request, format='json') + response = self.client.post(reverse("api:model-versions", kwargs={"pk": str(self.model2.id)}), request, format='json') self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) self.assertDictEqual(response.json(), {"detail": "You do not have permission to perform this action."}) @@ -88,8 +93,8 @@ class TestModelAPI(FixtureAPITestCase): """ self.client.force_login(self.user1) request = self.build_model_version_create_request() - with self.assertNumQueries(6): - response = self.client.post(reverse("api:model-version-create", kwargs={"pk": str(self.model2.id)}), request, format='json') + with self.assertNumQueries(7): + response = self.client.post(reverse("api:model-versions", kwargs={"pk": str(self.model2.id)}), request, format='json') self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) self.assertDictEqual(response.json(), {"detail": "You need a Contributor access to the model to create a new version."}) @@ -101,13 +106,17 @@ class TestModelAPI(FixtureAPITestCase): s3_presigned_url_mock.return_value = 'http://s3/upload_put_url' self.client.force_login(self.user1) request = self.build_model_version_create_request() - with self.assertNumQueries(8): - response = self.client.post(reverse('api:model-version-create', kwargs={"pk": str(self.model1.id)}), request, format='json') + + fake_now = timezone.now() + # To mock the creation date + with patch('django.utils.timezone.now') as mock_now: + mock_now.return_value = fake_now + with self.assertNumQueries(8): + response = self.client.post(reverse('api:model-versions', kwargs={"pk": str(self.model1.id)}), request, format='json') self.assertEqual(response.status_code, status.HTTP_201_CREATED) data = response.json() self.assertIn('id', data) df = ModelVersion.objects.get(id=data['id']) - self.maxDiff = None self.assertDictEqual( data, @@ -121,6 +130,7 @@ class TestModelAPI(FixtureAPITestCase): 'tag': None, 'size': request['size'], 'hash': request['hash'], + 'created': _format_datetime(fake_now), 's3_url': s3_presigned_url_mock.return_value, 's3_put_url': s3_presigned_url_mock.return_value } @@ -134,7 +144,7 @@ class TestModelAPI(FixtureAPITestCase): request = self.build_model_version_create_request() request['tag'] = '' with self.assertNumQueries(7): - response = self.client.post(reverse('api:model-version-create', kwargs={"pk": str(self.model1.id)}), request, format='json') + response = self.client.post(reverse('api:model-versions', kwargs={"pk": str(self.model1.id)}), request, format='json') self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertDictEqual(response.json(), {"tag": ["This field may not be blank."]}) @@ -151,8 +161,12 @@ class TestModelAPI(FixtureAPITestCase): 'hash': '05a5cdbaf05d3b6cc51fcb173d0057c0', 'size': 8, } - with self.assertNumQueries(9): - response = self.client.post(reverse('api:model-version-create', kwargs={"pk": str(self.model1.id)}), request, format='json') + fake_now = timezone.now() + # To mock the creation date + with patch('django.utils.timezone.now') as mock_now: + mock_now.return_value = fake_now + with self.assertNumQueries(9): + response = self.client.post(reverse('api:model-versions', kwargs={"pk": str(self.model1.id)}), request, format='json') self.assertEqual(response.status_code, status.HTTP_201_CREATED) data = response.json() self.assertIn('id', data) @@ -170,6 +184,7 @@ class TestModelAPI(FixtureAPITestCase): 'tag': request['tag'], 'size': request['size'], 'hash': request['hash'], + 'created': _format_datetime(fake_now), 's3_url': s3_presigned_url_mock.return_value, 's3_put_url': s3_presigned_url_mock.return_value } @@ -183,7 +198,7 @@ class TestModelAPI(FixtureAPITestCase): request = self.build_model_version_create_request() request['tag'] = self.model_version2.tag with self.assertNumQueries(8): - response = self.client.post(reverse('api:model-version-create', kwargs={"pk": str(self.model1.id)}), request, format='json') + response = self.client.post(reverse('api:model-versions', kwargs={"pk": str(self.model1.id)}), request, format='json') self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertDictEqual(response.json(), {"non_field_errors": ["A version for this model with this tag already exists."]}) @@ -195,7 +210,7 @@ class TestModelAPI(FixtureAPITestCase): self.client.force_login(self.user1) request = {'tag': 'production', 'hash': self.model_version5.hash, 'size': 32} with self.assertNumQueries(7): - response = self.client.post(reverse('api:model-version-create', kwargs={"pk": str(self.model1.id)}), request, format='json') + response = self.client.post(reverse('api:model-versions', kwargs={"pk": str(self.model1.id)}), request, format='json') self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertDictEqual(response.json(), {"hash": ["A version for this model with this hash already exists."]}) @@ -207,7 +222,7 @@ class TestModelAPI(FixtureAPITestCase): self.client.force_login(self.user1) request = {'tag': 'production', 'hash': self.model_version2.hash, 'size': 32} with self.assertNumQueries(7): - response = self.client.post(reverse('api:model-version-create', kwargs={"pk": str(self.model1.id)}), request, format='json') + response = self.client.post(reverse('api:model-versions', kwargs={"pk": str(self.model1.id)}), request, format='json') self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertDictEqual( response.json()['hash'], @@ -219,6 +234,7 @@ class TestModelAPI(FixtureAPITestCase): 'state': self.model_version2.state.value, 'configuration': self.model_version2.configuration, 'tag': self.model_version2.tag, + 'created': _format_datetime(self.model_version2.created), 'size': str(self.model_version2.size), 'hash': self.model_version2.hash, 's3_url': self.model_version2.s3_url, @@ -574,6 +590,7 @@ class TestModelAPI(FixtureAPITestCase): response = self.client.get(reverse('api:models')) self.assertEqual(response.status_code, status.HTTP_200_OK) models = response.json()['results'] + self.assertEqual(len(models), 2) self.assertListEqual([model['id'] for model in models], [str(self.model1.id), str(self.model2.id)]) @@ -588,3 +605,38 @@ class TestModelAPI(FixtureAPITestCase): self.assertEqual(len(models), 1) self.assertListEqual([model['id'] for model in models], [str(self.model2.id)]) + + def test_list_model_versions_requires_logged_in(self): + """To list a model's versions, you need to be logged in. + """ + response = self.client.get(reverse('api:model-versions', kwargs={"pk": str(self.model1.id)})) + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + self.assertDictEqual(response.json(), {"detail": "Authentication credentials were not provided."}) + + def test_list_model_versions_requires_guest_access(self): + """To view a model's version, you need guest access on the model + """ + self.client.force_login(self.user1) + with self.assertNumQueries(6): + response = self.client.get(reverse('api:model-versions', kwargs={"pk": str(self.model2.id)})) + + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + self.assertDictEqual(response.json(), {"detail": 'You need a guest access to the model to list its versions.'}) + + def test_list_model_versions_low_access(self): + """With only guest access rights on a model, you only see the available versions with a set tag + """ + self.client.force_login(self.user3) + with self.assertNumQueries(8): + response = self.client.get(reverse('api:model-versions', kwargs={"pk": str(self.model2.id)})) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertListEqual([version["id"] for version in response.json()['results']], [str(self.model_version3.id)]) + + def test_list_model_versions_full_access(self): + """With contributor access rights, you see every versions. + """ + self.client.force_login(self.user2) + with self.assertNumQueries(8): + response = self.client.get(reverse('api:model-versions', kwargs={"pk": str(self.model2.id)})) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertListEqual([version["id"] for version in response.json()['results']], [str(self.model_version4.id), str(self.model_version3.id)])