Skip to content
Snippets Groups Projects
Commit 93923a58 authored by Yoann Schneider's avatar Yoann Schneider :tennis: Committed by Erwan Rouchet
Browse files

New endpoint to create Models

parent ea811d5a
No related branches found
No related tags found
1 merge request!1642New endpoint to create Models
......@@ -88,7 +88,7 @@ from arkindex.documents.api.ml import (
from arkindex.documents.api.search import CorpusSearch
from arkindex.images.api import IIIFInformationCreate, IIIFURLCreate, ImageCreate, ImageElements, ImageRetrieve
from arkindex.project.openapi import OpenApiSchemaView
from arkindex.training.api import ModelVersionsCreate, ModelVersionsUpdate
from arkindex.training.api import ModelCreate, ModelVersionsCreate, ModelVersionsUpdate
from arkindex.users.api import (
CredentialsList,
CredentialsRetrieve,
......@@ -239,6 +239,7 @@ api = [
# ML models training
path('model/<uuid:pk>/versions/', ModelVersionsCreate.as_view(), name='model-version-create'),
path('modelversion/<uuid:pk>/', ModelVersionsUpdate.as_view(), name='model-version-update'),
path('model/', ModelCreate.as_view(), name='model-create'),
# Image management
path('image/', ImageCreate.as_view(), name='image-create'),
......
from drf_spectacular.utils import extend_schema, extend_schema_view
from rest_framework.exceptions import PermissionDenied
from rest_framework.exceptions import PermissionDenied, ValidationError
from rest_framework.generics import CreateAPIView, UpdateAPIView
from arkindex.project.mixins import TrainingModelMixin
from arkindex.project.permissions import IsVerified
from arkindex.training.models import Model, ModelVersion
from arkindex.training.serializers import ModelVersionCreateSerializer, ModelVersionSerializer
from arkindex.training.serializers import (
CreateModelErrorResponseSerializer,
ModelSerializer,
ModelVersionCreateSerializer,
ModelVersionSerializer,
)
@extend_schema(tags=['training'])
......@@ -72,3 +77,34 @@ class ModelVersionsUpdate(TrainingModelMixin, UpdateAPIView):
super().check_object_permissions(request, model_version.model)
if not self.has_write_access(model_version.model):
raise PermissionDenied(detail='You need a Contributor access to the model to edit this version.')
@extend_schema(tags=['training'])
@extend_schema_view(
post=extend_schema(
description=(
'Create a new Machine Learning model.'
),
responses={
200: ModelSerializer,
400: CreateModelErrorResponseSerializer,
403: None
},
),
)
class ModelCreate(TrainingModelMixin, CreateAPIView):
permission_classes = (IsVerified, )
serializer_class = ModelSerializer
def perform_create(self, serializer):
model_name = serializer.validated_data.get('name')
existing_model = Model.objects.using('default').filter(name=model_name).first()
if existing_model:
if self.has_read_access(existing_model):
raise ValidationError({
'id': str(existing_model.id),
'name': 'A model with this name already exists',
})
else:
raise PermissionDenied()
return serializer.save()
......@@ -3,12 +3,13 @@ import re
from drf_spectacular.utils import extend_schema_field
from rest_framework import serializers
from rest_framework.serializers import ValidationError
from rest_framework.exceptions import ValidationError
from rest_framework.validators import UniqueTogetherValidator
from arkindex.project.mixins import TrainingModelMixin
from arkindex.project.serializer_fields import EnumField
from arkindex.training.models import ModelVersion, ModelVersionState
from arkindex.training.models import Model, ModelVersion, ModelVersionState
from arkindex.users.models import Role
def _model_from_context(serializer_field):
......@@ -18,6 +19,31 @@ def _model_from_context(serializer_field):
_model_from_context.requires_context = True
class ModelSerializer(TrainingModelMixin, serializers.ModelSerializer):
# Actually define the field to avoid the field-level automatically generated UniqueValidator
name = serializers.CharField(max_length=100)
class Meta:
model = Model
fields = ('id', 'created', 'updated', 'name', 'description')
def create(self, validated_data):
instance = super().create(validated_data)
# Create admin membership for creator user
instance.memberships.create(
user=self.context['request'].user,
level=Role.Admin.value
)
return instance
class CreateModelErrorResponseSerializer(serializers.Serializer):
id = serializers.UUIDField(required=False, help_text="UUID of an existing model, if the error comes from a duplicate name.")
name = serializers.CharField(required=False, help_text="Name of an existing model, if the error comes from a duplicate name.")
detail = serializers.CharField(required=False, help_text="A generic error message when an error occurs outside of a specific field.")
class ModelVersionSerializer(TrainingModelMixin, serializers.ModelSerializer):
model = serializers.HiddenField(default=_model_from_context)
description = serializers.CharField(allow_blank=True, style={'base_template': 'textarea.html'})
......
......@@ -7,7 +7,7 @@ from rest_framework import status
from arkindex.project.tests import FixtureAPITestCase
from arkindex.training.models import Model, ModelVersion, ModelVersionState
from arkindex.users.models import Group, Right, User
from arkindex.users.models import Group, Right, Role, User
class TestModelVersion(FixtureAPITestCase):
......@@ -404,3 +404,85 @@ class TestModelVersion(FixtureAPITestCase):
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertDictEqual(response.json(), {"state": ['MD5 hashes do not match']})
def test_create_model_requires_verified(self):
"""Model creation requires being verified
"""
user = User.objects.create(display_name="Not Verified", verified_email=False)
self.client.force_login(user)
response = self.client.post(reverse('api:model-create'), {"name": "The Best Classifier"})
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertDictEqual(response.json(), {"detail": "You do not have permission to perform this action."})
def test_create_model_requires_logged_in(self):
"""Model creation requires being logged in
"""
response = self.client.post(reverse('api:model-create'), {"name": "The Best Classifier"})
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertDictEqual(response.json(), {"detail": "Authentication credentials were not provided."})
def test_create_model_requires_name(self):
"""To create a model you must provide a name
"""
self.client.force_login(self.user1)
with self.assertNumQueries(2):
response = self.client.post(reverse('api:model-create'))
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertDictEqual(response.json(), {"name": ["This field is required."]})
def test_create_model_requires_name_not_blank(self):
"""To create a model you must provide a name
"""
self.client.force_login(self.user1)
with self.assertNumQueries(2):
response = self.client.post(reverse('api:model-create'), {"name": ''})
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertDictEqual(response.json(), {"name": ["This field may not be blank."]})
def test_create_model_new_name(self):
"""Create a model with a name no other model has and a description. The creator also gets admin rights on the newly created model.
"""
self.client.force_login(self.user1)
request = {
"name": 'The Best Model Ever',
"description": "This is actually the best model ever."
}
with self.assertNumQueries(6):
response = self.client.post(reverse('api:model-create'), request)
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
model = Model.objects.get(id=response.json()['id'])
self.assertEqual(model.name, request.get("name"))
self.assertEqual(model.description, request.get("description"))
# Check that user has admin rights on the model
user_right = model.memberships.get(user=self.user1)
self.assertEqual(user_right.level, Role.Admin.value)
def test_create_model_name_taken_with_access(self):
"""Raises a 400 with the model_id when creating a model with a name that is already used for another model
but with guest access rights on this model
"""
self.client.force_login(self.user1)
request = {
"name": self.model1.name,
}
with self.assertNumQueries(6):
response = self.client.post(reverse('api:model-create'), request)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertDictEqual(response.json(), {"id": str(self.model1.id), "name": 'A model with this name already exists'})
def test_create_model_name_taken_no_access(self):
"""Raises a 403 with no additional information when creating a model with a name that is already used for another model
but without access rights on this model
"""
self.client.force_login(self.user3)
request = {
"name": self.model1.name,
}
with self.assertNumQueries(6):
response = self.client.post(reverse('api:model-create'), request)
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertDictEqual(response.json(), {"detail": "You do not have permission to perform this action."})
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment