From 93923a58ce29bdaacf6b56ec8fe54a34988a8593 Mon Sep 17 00:00:00 2001 From: Yoann Schneider <yschneider@teklia.com> Date: Tue, 29 Mar 2022 15:38:40 +0000 Subject: [PATCH] New endpoint to create Models --- arkindex/project/api_v1.py | 3 +- arkindex/training/api.py | 40 ++++++++- arkindex/training/serializers.py | 30 ++++++- .../training/tests/test_model_versions.py | 84 ++++++++++++++++++- 4 files changed, 151 insertions(+), 6 deletions(-) diff --git a/arkindex/project/api_v1.py b/arkindex/project/api_v1.py index 921246d2d7..20700ca4cb 100644 --- a/arkindex/project/api_v1.py +++ b/arkindex/project/api_v1.py @@ -88,7 +88,7 @@ from arkindex.documents.api.ml import ( from arkindex.documents.api.search import CorpusSearch from arkindex.images.api import IIIFInformationCreate, IIIFURLCreate, ImageCreate, ImageElements, ImageRetrieve from arkindex.project.openapi import OpenApiSchemaView -from arkindex.training.api import ModelVersionsCreate, ModelVersionsUpdate +from arkindex.training.api import ModelCreate, ModelVersionsCreate, ModelVersionsUpdate from arkindex.users.api import ( CredentialsList, CredentialsRetrieve, @@ -239,6 +239,7 @@ api = [ # ML models training path('model/<uuid:pk>/versions/', ModelVersionsCreate.as_view(), name='model-version-create'), path('modelversion/<uuid:pk>/', ModelVersionsUpdate.as_view(), name='model-version-update'), + path('model/', ModelCreate.as_view(), name='model-create'), # Image management path('image/', ImageCreate.as_view(), name='image-create'), diff --git a/arkindex/training/api.py b/arkindex/training/api.py index aa9e9b7d0b..fb774cb73a 100644 --- a/arkindex/training/api.py +++ b/arkindex/training/api.py @@ -1,11 +1,16 @@ from drf_spectacular.utils import extend_schema, extend_schema_view -from rest_framework.exceptions import PermissionDenied +from rest_framework.exceptions import PermissionDenied, ValidationError from rest_framework.generics import CreateAPIView, UpdateAPIView from arkindex.project.mixins import TrainingModelMixin from arkindex.project.permissions import IsVerified from arkindex.training.models import Model, ModelVersion -from arkindex.training.serializers import ModelVersionCreateSerializer, ModelVersionSerializer +from arkindex.training.serializers import ( + CreateModelErrorResponseSerializer, + ModelSerializer, + ModelVersionCreateSerializer, + ModelVersionSerializer, +) @extend_schema(tags=['training']) @@ -72,3 +77,34 @@ class ModelVersionsUpdate(TrainingModelMixin, UpdateAPIView): super().check_object_permissions(request, model_version.model) if not self.has_write_access(model_version.model): raise PermissionDenied(detail='You need a Contributor access to the model to edit this version.') + + +@extend_schema(tags=['training']) +@extend_schema_view( + post=extend_schema( + description=( + 'Create a new Machine Learning model.' + ), + responses={ + 200: ModelSerializer, + 400: CreateModelErrorResponseSerializer, + 403: None + }, + ), +) +class ModelCreate(TrainingModelMixin, CreateAPIView): + permission_classes = (IsVerified, ) + serializer_class = ModelSerializer + + def perform_create(self, serializer): + model_name = serializer.validated_data.get('name') + existing_model = Model.objects.using('default').filter(name=model_name).first() + if existing_model: + if self.has_read_access(existing_model): + raise ValidationError({ + 'id': str(existing_model.id), + 'name': 'A model with this name already exists', + }) + else: + raise PermissionDenied() + return serializer.save() diff --git a/arkindex/training/serializers.py b/arkindex/training/serializers.py index d3edc8ad51..2626c7b4ee 100644 --- a/arkindex/training/serializers.py +++ b/arkindex/training/serializers.py @@ -3,12 +3,13 @@ import re from drf_spectacular.utils import extend_schema_field from rest_framework import serializers -from rest_framework.serializers import ValidationError +from rest_framework.exceptions import ValidationError from rest_framework.validators import UniqueTogetherValidator from arkindex.project.mixins import TrainingModelMixin from arkindex.project.serializer_fields import EnumField -from arkindex.training.models import ModelVersion, ModelVersionState +from arkindex.training.models import Model, ModelVersion, ModelVersionState +from arkindex.users.models import Role def _model_from_context(serializer_field): @@ -18,6 +19,31 @@ def _model_from_context(serializer_field): _model_from_context.requires_context = True +class ModelSerializer(TrainingModelMixin, serializers.ModelSerializer): + # Actually define the field to avoid the field-level automatically generated UniqueValidator + name = serializers.CharField(max_length=100) + + class Meta: + model = Model + fields = ('id', 'created', 'updated', 'name', 'description') + + def create(self, validated_data): + instance = super().create(validated_data) + + # Create admin membership for creator user + instance.memberships.create( + user=self.context['request'].user, + level=Role.Admin.value + ) + return instance + + +class CreateModelErrorResponseSerializer(serializers.Serializer): + id = serializers.UUIDField(required=False, help_text="UUID of an existing model, if the error comes from a duplicate name.") + name = serializers.CharField(required=False, help_text="Name of an existing model, if the error comes from a duplicate name.") + detail = serializers.CharField(required=False, help_text="A generic error message when an error occurs outside of a specific field.") + + class ModelVersionSerializer(TrainingModelMixin, serializers.ModelSerializer): model = serializers.HiddenField(default=_model_from_context) description = serializers.CharField(allow_blank=True, style={'base_template': 'textarea.html'}) diff --git a/arkindex/training/tests/test_model_versions.py b/arkindex/training/tests/test_model_versions.py index 00f7727295..96b3cb1d46 100644 --- a/arkindex/training/tests/test_model_versions.py +++ b/arkindex/training/tests/test_model_versions.py @@ -7,7 +7,7 @@ from rest_framework import status from arkindex.project.tests import FixtureAPITestCase from arkindex.training.models import Model, ModelVersion, ModelVersionState -from arkindex.users.models import Group, Right, User +from arkindex.users.models import Group, Right, Role, User class TestModelVersion(FixtureAPITestCase): @@ -404,3 +404,85 @@ class TestModelVersion(FixtureAPITestCase): self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertDictEqual(response.json(), {"state": ['MD5 hashes do not match']}) + + def test_create_model_requires_verified(self): + """Model creation requires being verified + """ + user = User.objects.create(display_name="Not Verified", verified_email=False) + self.client.force_login(user) + response = self.client.post(reverse('api:model-create'), {"name": "The Best Classifier"}) + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + self.assertDictEqual(response.json(), {"detail": "You do not have permission to perform this action."}) + + def test_create_model_requires_logged_in(self): + """Model creation requires being logged in + """ + response = self.client.post(reverse('api:model-create'), {"name": "The Best Classifier"}) + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + self.assertDictEqual(response.json(), {"detail": "Authentication credentials were not provided."}) + + def test_create_model_requires_name(self): + """To create a model you must provide a name + """ + self.client.force_login(self.user1) + with self.assertNumQueries(2): + response = self.client.post(reverse('api:model-create')) + + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertDictEqual(response.json(), {"name": ["This field is required."]}) + + def test_create_model_requires_name_not_blank(self): + """To create a model you must provide a name + """ + self.client.force_login(self.user1) + with self.assertNumQueries(2): + response = self.client.post(reverse('api:model-create'), {"name": ''}) + + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertDictEqual(response.json(), {"name": ["This field may not be blank."]}) + + def test_create_model_new_name(self): + """Create a model with a name no other model has and a description. The creator also gets admin rights on the newly created model. + """ + self.client.force_login(self.user1) + request = { + "name": 'The Best Model Ever', + "description": "This is actually the best model ever." + } + with self.assertNumQueries(6): + response = self.client.post(reverse('api:model-create'), request) + + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + model = Model.objects.get(id=response.json()['id']) + self.assertEqual(model.name, request.get("name")) + self.assertEqual(model.description, request.get("description")) + + # Check that user has admin rights on the model + user_right = model.memberships.get(user=self.user1) + self.assertEqual(user_right.level, Role.Admin.value) + + def test_create_model_name_taken_with_access(self): + """Raises a 400 with the model_id when creating a model with a name that is already used for another model + but with guest access rights on this model + """ + self.client.force_login(self.user1) + request = { + "name": self.model1.name, + } + with self.assertNumQueries(6): + response = self.client.post(reverse('api:model-create'), request) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertDictEqual(response.json(), {"id": str(self.model1.id), "name": 'A model with this name already exists'}) + + def test_create_model_name_taken_no_access(self): + """Raises a 403 with no additional information when creating a model with a name that is already used for another model + but without access rights on this model + """ + self.client.force_login(self.user3) + request = { + "name": self.model1.name, + } + with self.assertNumQueries(6): + response = self.client.post(reverse('api:model-create'), request) + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + self.assertDictEqual(response.json(), {"detail": "You do not have permission to perform this action."}) -- GitLab