From 93923a58ce29bdaacf6b56ec8fe54a34988a8593 Mon Sep 17 00:00:00 2001
From: Yoann Schneider <yschneider@teklia.com>
Date: Tue, 29 Mar 2022 15:38:40 +0000
Subject: [PATCH] New endpoint to create Models

---
 arkindex/project/api_v1.py                    |  3 +-
 arkindex/training/api.py                      | 40 ++++++++-
 arkindex/training/serializers.py              | 30 ++++++-
 .../training/tests/test_model_versions.py     | 84 ++++++++++++++++++-
 4 files changed, 151 insertions(+), 6 deletions(-)

diff --git a/arkindex/project/api_v1.py b/arkindex/project/api_v1.py
index 921246d2d7..20700ca4cb 100644
--- a/arkindex/project/api_v1.py
+++ b/arkindex/project/api_v1.py
@@ -88,7 +88,7 @@ from arkindex.documents.api.ml import (
 from arkindex.documents.api.search import CorpusSearch
 from arkindex.images.api import IIIFInformationCreate, IIIFURLCreate, ImageCreate, ImageElements, ImageRetrieve
 from arkindex.project.openapi import OpenApiSchemaView
-from arkindex.training.api import ModelVersionsCreate, ModelVersionsUpdate
+from arkindex.training.api import ModelCreate, ModelVersionsCreate, ModelVersionsUpdate
 from arkindex.users.api import (
     CredentialsList,
     CredentialsRetrieve,
@@ -239,6 +239,7 @@ api = [
     # ML models training
     path('model/<uuid:pk>/versions/', ModelVersionsCreate.as_view(), name='model-version-create'),
     path('modelversion/<uuid:pk>/', ModelVersionsUpdate.as_view(), name='model-version-update'),
+    path('model/', ModelCreate.as_view(), name='model-create'),
 
     # Image management
     path('image/', ImageCreate.as_view(), name='image-create'),
diff --git a/arkindex/training/api.py b/arkindex/training/api.py
index aa9e9b7d0b..fb774cb73a 100644
--- a/arkindex/training/api.py
+++ b/arkindex/training/api.py
@@ -1,11 +1,16 @@
 from drf_spectacular.utils import extend_schema, extend_schema_view
-from rest_framework.exceptions import PermissionDenied
+from rest_framework.exceptions import PermissionDenied, ValidationError
 from rest_framework.generics import CreateAPIView, UpdateAPIView
 
 from arkindex.project.mixins import TrainingModelMixin
 from arkindex.project.permissions import IsVerified
 from arkindex.training.models import Model, ModelVersion
-from arkindex.training.serializers import ModelVersionCreateSerializer, ModelVersionSerializer
+from arkindex.training.serializers import (
+    CreateModelErrorResponseSerializer,
+    ModelSerializer,
+    ModelVersionCreateSerializer,
+    ModelVersionSerializer,
+)
 
 
 @extend_schema(tags=['training'])
@@ -72,3 +77,34 @@ class ModelVersionsUpdate(TrainingModelMixin, UpdateAPIView):
         super().check_object_permissions(request, model_version.model)
         if not self.has_write_access(model_version.model):
             raise PermissionDenied(detail='You need a Contributor access to the model to edit this version.')
+
+
+@extend_schema(tags=['training'])
+@extend_schema_view(
+    post=extend_schema(
+        description=(
+            'Create a new Machine Learning model.'
+        ),
+        responses={
+            200: ModelSerializer,
+            400: CreateModelErrorResponseSerializer,
+            403: None
+        },
+    ),
+)
+class ModelCreate(TrainingModelMixin, CreateAPIView):
+    permission_classes = (IsVerified, )
+    serializer_class = ModelSerializer
+
+    def perform_create(self, serializer):
+        model_name = serializer.validated_data.get('name')
+        existing_model = Model.objects.using('default').filter(name=model_name).first()
+        if existing_model:
+            if self.has_read_access(existing_model):
+                raise ValidationError({
+                    'id': str(existing_model.id),
+                    'name': 'A model with this name already exists',
+                })
+            else:
+                raise PermissionDenied()
+        return serializer.save()
diff --git a/arkindex/training/serializers.py b/arkindex/training/serializers.py
index d3edc8ad51..2626c7b4ee 100644
--- a/arkindex/training/serializers.py
+++ b/arkindex/training/serializers.py
@@ -3,12 +3,13 @@ import re
 
 from drf_spectacular.utils import extend_schema_field
 from rest_framework import serializers
-from rest_framework.serializers import ValidationError
+from rest_framework.exceptions import ValidationError
 from rest_framework.validators import UniqueTogetherValidator
 
 from arkindex.project.mixins import TrainingModelMixin
 from arkindex.project.serializer_fields import EnumField
-from arkindex.training.models import ModelVersion, ModelVersionState
+from arkindex.training.models import Model, ModelVersion, ModelVersionState
+from arkindex.users.models import Role
 
 
 def _model_from_context(serializer_field):
@@ -18,6 +19,31 @@ def _model_from_context(serializer_field):
 _model_from_context.requires_context = True
 
 
+class ModelSerializer(TrainingModelMixin, serializers.ModelSerializer):
+    # Actually define the field to avoid the field-level automatically generated UniqueValidator
+    name = serializers.CharField(max_length=100)
+
+    class Meta:
+        model = Model
+        fields = ('id', 'created', 'updated', 'name', 'description')
+
+    def create(self, validated_data):
+        instance = super().create(validated_data)
+
+        # Create admin membership for creator user
+        instance.memberships.create(
+            user=self.context['request'].user,
+            level=Role.Admin.value
+        )
+        return instance
+
+
+class CreateModelErrorResponseSerializer(serializers.Serializer):
+    id = serializers.UUIDField(required=False, help_text="UUID of an existing model, if the error comes from a duplicate name.")
+    name = serializers.CharField(required=False, help_text="Name of an existing model, if the error comes from a duplicate name.")
+    detail = serializers.CharField(required=False, help_text="A generic error message when an error occurs outside of a specific field.")
+
+
 class ModelVersionSerializer(TrainingModelMixin, serializers.ModelSerializer):
     model = serializers.HiddenField(default=_model_from_context)
     description = serializers.CharField(allow_blank=True, style={'base_template': 'textarea.html'})
diff --git a/arkindex/training/tests/test_model_versions.py b/arkindex/training/tests/test_model_versions.py
index 00f7727295..96b3cb1d46 100644
--- a/arkindex/training/tests/test_model_versions.py
+++ b/arkindex/training/tests/test_model_versions.py
@@ -7,7 +7,7 @@ from rest_framework import status
 
 from arkindex.project.tests import FixtureAPITestCase
 from arkindex.training.models import Model, ModelVersion, ModelVersionState
-from arkindex.users.models import Group, Right, User
+from arkindex.users.models import Group, Right, Role, User
 
 
 class TestModelVersion(FixtureAPITestCase):
@@ -404,3 +404,85 @@ class TestModelVersion(FixtureAPITestCase):
 
         self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
         self.assertDictEqual(response.json(), {"state": ['MD5 hashes do not match']})
+
+    def test_create_model_requires_verified(self):
+        """Model creation requires being verified
+        """
+        user = User.objects.create(display_name="Not Verified", verified_email=False)
+        self.client.force_login(user)
+        response = self.client.post(reverse('api:model-create'), {"name": "The Best Classifier"})
+        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+        self.assertDictEqual(response.json(), {"detail": "You do not have permission to perform this action."})
+
+    def test_create_model_requires_logged_in(self):
+        """Model creation requires being logged in
+        """
+        response = self.client.post(reverse('api:model-create'), {"name": "The Best Classifier"})
+        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+        self.assertDictEqual(response.json(), {"detail": "Authentication credentials were not provided."})
+
+    def test_create_model_requires_name(self):
+        """To create a model you must provide a name
+        """
+        self.client.force_login(self.user1)
+        with self.assertNumQueries(2):
+            response = self.client.post(reverse('api:model-create'))
+
+        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+        self.assertDictEqual(response.json(), {"name": ["This field is required."]})
+
+    def test_create_model_requires_name_not_blank(self):
+        """To create a model you must provide a name
+        """
+        self.client.force_login(self.user1)
+        with self.assertNumQueries(2):
+            response = self.client.post(reverse('api:model-create'), {"name": ''})
+
+        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+        self.assertDictEqual(response.json(), {"name": ["This field may not be blank."]})
+
+    def test_create_model_new_name(self):
+        """Create a model with a name no other model has and a description. The creator also gets admin rights on the newly created model.
+        """
+        self.client.force_login(self.user1)
+        request = {
+            "name": 'The Best Model Ever',
+            "description": "This is actually the best model ever."
+        }
+        with self.assertNumQueries(6):
+            response = self.client.post(reverse('api:model-create'), request)
+
+        self.assertEqual(response.status_code, status.HTTP_201_CREATED)
+        model = Model.objects.get(id=response.json()['id'])
+        self.assertEqual(model.name, request.get("name"))
+        self.assertEqual(model.description, request.get("description"))
+
+        # Check that user has admin rights on the model
+        user_right = model.memberships.get(user=self.user1)
+        self.assertEqual(user_right.level, Role.Admin.value)
+
+    def test_create_model_name_taken_with_access(self):
+        """Raises a 400 with the model_id when creating a model with a name that is already used for another model
+        but with guest access rights on this model
+        """
+        self.client.force_login(self.user1)
+        request = {
+            "name": self.model1.name,
+        }
+        with self.assertNumQueries(6):
+            response = self.client.post(reverse('api:model-create'), request)
+        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+        self.assertDictEqual(response.json(), {"id": str(self.model1.id), "name": 'A model with this name already exists'})
+
+    def test_create_model_name_taken_no_access(self):
+        """Raises a 403 with no additional information when creating a model with a name that is already used for another model
+        but without access rights on this model
+        """
+        self.client.force_login(self.user3)
+        request = {
+            "name": self.model1.name,
+        }
+        with self.assertNumQueries(6):
+            response = self.client.post(reverse('api:model-create'), request)
+        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+        self.assertDictEqual(response.json(), {"detail": "You do not have permission to perform this action."})
-- 
GitLab