Skip to content
Snippets Groups Projects
Commit 7e357eeb authored by Yoann Schneider's avatar Yoann Schneider :tennis: Committed by Erwan Rouchet
Browse files

Create model version endpoint

parent 88e0feb2
No related branches found
No related tags found
1 merge request: !1634 "Create model version endpoint"
......@@ -88,6 +88,7 @@ from arkindex.documents.api.ml import (
from arkindex.documents.api.search import CorpusSearch
from arkindex.images.api import IIIFInformationCreate, IIIFURLCreate, ImageCreate, ImageElements, ImageRetrieve
from arkindex.project.openapi import OpenApiSchemaView
from arkindex.training.api import ModelVersionsCreate
from arkindex.users.api import (
CredentialsList,
CredentialsRetrieve,
......@@ -235,6 +236,9 @@ api = [
path('process/<uuid:pk>/template/', CreateProcessTemplate.as_view(), name='create-process-template'),
path('process/<uuid:pk>/apply/', ApplyProcessTemplate.as_view(), name='apply-process-template'),
# ML models training
path('model/<uuid:pk>/versions/', ModelVersionsCreate.as_view(), name='model-version-create'),
# Image management
path('image/', ImageCreate.as_view(), name='image-create'),
path('image/iiif/url/', IIIFURLCreate.as_view(), name='iiif-url-create'),
......
......@@ -195,6 +195,7 @@ def get_settings_parser(base_dir):
s3_parser.add_option('thumbnails_bucket', type=str, default='thumbnails')
s3_parser.add_option('staging_bucket', type=str, default='staging')
s3_parser.add_option('export_bucket', type=str, default='export')
s3_parser.add_option('training_bucket', type=str, default='training')
s3_parser.add_option('ponos_logs_bucket', type=str, default='ponos-logs')
s3_parser.add_option('ponos_artifacts_bucket', type=str, default='ponos-artifacts')
......
......@@ -271,6 +271,10 @@ SPECTACULAR_SETTINGS = {
'name': 'ml',
'description': 'Machine Learning tools and results',
},
{
'name': 'training',
'description': 'Machine Learning models training',
},
{'name': 'oauth'},
{'name': 'ponos'},
{'name': 'repos'},
......@@ -488,6 +492,7 @@ PONOS_S3_ARTIFACTS_BUCKET = conf['s3']['ponos_artifacts_bucket']
AWS_THUMBNAIL_BUCKET = conf['s3']['thumbnails_bucket']
AWS_STAGING_BUCKET = conf['s3']['staging_bucket']
AWS_EXPORT_BUCKET = conf['s3']['export_bucket']
AWS_TRAINING_BUCKET = conf['s3']['training_bucket']
# Ponos integration
_ponos_env = {
......
......@@ -78,6 +78,7 @@ s3:
secret_access_key: null
staging_bucket: staging
thumbnails_bucket: thumbnails
training_bucket: training
secret_key: jf0w^y&ml(caax8f&a1mub)(js9(l5mhbbhosz3gi+m01ex+lo
sentry:
dsn: null
......
......@@ -93,6 +93,7 @@ s3:
secret_access_key: hunter2
staging_bucket: dropboxbutworse
thumbnails_bucket: toenails
training_bucket: nachoneko
secret_key: abcdef
sentry:
dsn: https://nowhere
......
from drf_spectacular.utils import extend_schema, extend_schema_view
from rest_framework.exceptions import PermissionDenied
from rest_framework.generics import CreateAPIView
from arkindex.project.mixins import TrainingModelMixin
from arkindex.project.permissions import IsVerified
from arkindex.training.models import Model
from arkindex.training.serializers import ModelVersionCreateSerializer
@extend_schema(tags=['training'])
@extend_schema_view(
    post=extend_schema(
        operation_id='CreateModelVersion',
        description=(
            'Create a new version for a Machine Learning model.\n\n'
            'Requires a **contributor** access to the model.'
        ),
        responses={
            # CreateAPIView answers a successful POST with HTTP 201 Created,
            # so the schema must advertise 201 rather than 200.
            201: ModelVersionCreateSerializer,
            403: None,
        },
    ),
)
class ModelVersionsCreate(TrainingModelMixin, CreateAPIView):
    """
    Create a new version of the model identified by the ``pk`` URL argument.

    The model is resolved through ``get_object()`` and exposed to the
    serializer as ``context['model']``. Only verified users may call this
    endpoint, and a ``PermissionDenied`` (HTTP 403) is raised when
    ``has_write_access`` rejects the model.
    """
    permission_classes = (IsVerified, )
    serializer_class = ModelVersionCreateSerializer
    # Empty class-level queryset; the real one is built per request in get_queryset
    queryset = Model.objects.none()

    def get_queryset(self):
        """Return the queryset restricted to the model targeted by the URL."""
        return Model.objects.filter(id=self.kwargs['pk'])

    def get_serializer_context(self):
        """Add the targeted model to the serializer context under the ``model`` key."""
        context = super().get_serializer_context()
        # get_object() also runs check_object_permissions on the retrieved model
        context['model'] = self.get_object()
        return context

    def check_object_permissions(self, request, model):
        """
        Run the default object permission checks, then require write access.

        Raises PermissionDenied when ``has_write_access`` (presumably provided
        by TrainingModelMixin, to be confirmed) returns a falsy value.
        """
        super().check_object_permissions(request, model)
        if not self.has_write_access(model):
            raise PermissionDenied(detail='You need a Contributor access to the model to create a new version.')
# Generated by Django 4.0.2 on 2022-03-23 14:06
from django.db import migrations, models
class Migration(migrations.Migration):
    """Alter ``ModelVersion.tag`` so that it is nullable, blankable and defaults to ``None``."""

    dependencies = [('training', '0001_initial')]

    operations = [
        migrations.AlterField(
            model_name='modelversion',
            name='tag',
            field=models.CharField(
                max_length=50,
                null=True,
                blank=True,
                default=None,
            ),
        ),
    ]
from django.conf import settings
from django.contrib.contenttypes.fields import GenericRelation
from django.db import models
from django.utils.functional import cached_property
from enumfields import Enum, EnumField
from arkindex.project.aws import S3FileMixin
from arkindex.project.fields import MD5HashField
from arkindex.project.models import IndexableModel
......@@ -36,7 +39,7 @@ class ModelVersionState(Enum):
Error = 'error'
class ModelVersion(IndexableModel):
class ModelVersion(S3FileMixin, IndexableModel):
"""
A specific Model version
"""
......@@ -46,7 +49,7 @@ class ModelVersion(IndexableModel):
description = models.TextField(default="")
tag = models.CharField(null=True, max_length=50, blank=True)
tag = models.CharField(null=True, max_length=50, blank=True, default=None)
state = EnumField(ModelVersionState, default=ModelVersionState.Created)
......@@ -63,3 +66,8 @@ class ModelVersion(IndexableModel):
unique_together = (
('model', 'tag'),
)
s3_bucket = settings.AWS_TRAINING_BUCKET
@cached_property
def s3_key(self):
    """Return the S3 object key of this version: its ID followed by ``.zst``."""
    archive_key = f'{self.id}.zst'
    return archive_key
import re
from drf_spectacular.utils import extend_schema_field
from rest_framework import serializers
from rest_framework.validators import UniqueTogetherValidator
from arkindex.project.mixins import TrainingModelMixin
from arkindex.project.serializer_fields import EnumField
from arkindex.training.models import ModelVersion, ModelVersionState
def _model_from_context(serializer_field):
return serializer_field.context.get('model')
_model_from_context.requires_context = True
class ModelVersionSerializer(TrainingModelMixin, serializers.ModelSerializer):
    """
    Serialize a model version, including the ``s3_url`` exposed by the instance.

    The ``model`` field is hidden and filled from the serializer context; a
    (model, tag) pair must be unique among versions that have a tag.
    """
    model = serializers.HiddenField(default=_model_from_context)
    tag = serializers.CharField(
        allow_blank=True,
        allow_null=True,
        max_length=50,
        required=False,
        default=None,
    )
    parent = serializers.PrimaryKeyRelatedField(
        queryset=ModelVersion.objects.none(),
        default=None,
    )
    state = EnumField(ModelVersionState, read_only=True)
    s3_url = serializers.SerializerMethodField()

    class Meta:
        model = ModelVersion
        fields = (
            'id',
            'model',
            'model_id',
            'parent',
            'description',
            'tag',
            'hash',
            'state',
            'size',
            'configuration',
            's3_url',
        )
        read_only_fields = (
            'id',
            'model_id',
            'parent',
            'state',
            'hash',
            'size',
            's3_url',
        )
        validators = [
            UniqueTogetherValidator(
                # Versions without a tag are left out of the uniqueness check
                queryset=ModelVersion.objects.filter(tag__isnull=False),
                fields=['model', 'tag'],
                message='A version for this model with this tag already exists.'
            )
        ]

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # The CreateModelVersion endpoint puts the targeted model in the context;
        # without it, the parent queryset stays empty.
        target_model = self.context.get('model')
        if target_model is None:
            return
        self.fields['parent'].queryset = ModelVersion.objects.filter(model_id=target_model.id)

    @extend_schema_field(serializers.CharField(allow_null=True))
    def get_s3_url(self, obj):
        """Expose the ``s3_url`` attribute of the serialized version."""
        return obj.s3_url
class ModelVersionCreateSerializer(ModelVersionSerializer):
    """
    Serializer used to create a Model Version from the hash and size of its archive.

    Adds a read-only ``s3_put_url`` to the fields of ModelVersionSerializer.
    """
    hash = serializers.RegexField(
        re.compile(r'[0-9A-Fa-f]{32}'),
        min_length=32,
        max_length=32,
    )
    size = serializers.IntegerField(min_value=0)
    s3_put_url = serializers.SerializerMethodField()

    class Meta(ModelVersionSerializer.Meta):
        fields = (*ModelVersionSerializer.Meta.fields, 's3_put_url')
        read_only_fields = (*ModelVersionSerializer.Meta.read_only_fields, 's3_put_url')

    @extend_schema_field(serializers.CharField(allow_null=True))
    def get_s3_put_url(self, obj):
        """Expose the ``s3_put_url`` attribute of the serialized version."""
        return obj.s3_put_url
from unittest.mock import patch
from django.urls import reverse
from rest_framework import status
from arkindex.project.tests import FixtureAPITestCase
from arkindex.training.models import Model, ModelVersion, ModelVersionState
from arkindex.users.models import Group, Right, User
class TestModelVersion(FixtureAPITestCase):
    """
    API tests for the model version creation endpoint
    """

    @classmethod
    def setUpTestData(cls):
        r"""Set up a small rights configuration shared by every test

         User1      User2      User3
           |          |          |
           |         100         |
           |          |          |
          100       Group1      10
           |        /    \       |
           |      50     100     |
           |      /        \     |
            Model1          Model2

        Model3 has no rights attached.
        """
        super().setUpTestData()

        # Create three models
        cls.model1, cls.model2, cls.model3 = (
            Model.objects.create(name=model_name)
            for model_name in ("First Model", "Second Model", "Third Model")
        )

        # Create three verified users, each with different rights on the models
        cls.user1, cls.user2, cls.user3 = (
            User.objects.create(
                email=f'user{index}@test.test',
                display_name=f'User {index}',
                verified_email=True,
            )
            for index in (1, 2, 3)
        )

        # Create the group
        cls.group1 = Group.objects.create(name="Group1")

        # Access rights
        rights = [
            Right(user=cls.user1, content_object=cls.model1, level=100),
            Right(user=cls.user2, content_object=cls.group1, level=100),
            Right(user=cls.user3, content_object=cls.model2, level=10),
            Right(group=cls.group1, content_object=cls.model1, level=50),
            Right(group=cls.group1, content_object=cls.model2, level=100),
        ]
        Right.objects.bulk_create(rights)

    def test_create_model_version_requires_verified(self):
        """
        A user whose email is not verified gets the generic 403 response
        """
        unverified = User.objects.create(
            email='not_verified@mail.com',
            display_name='Not Verified',
            verified_email=False,
        )
        self.client.force_login(unverified)
        url = reverse("api:model-version-create", kwargs={"pk": str(self.model2.id)})
        payload = {'hash': '05a5cdbaf05d3b6cc51fcb173d0057c0', 'size': 8}

        with self.assertNumQueries(2):
            response = self.client.post(url, payload)

        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
        self.assertDictEqual(
            response.json(),
            {"detail": "You do not have permission to perform this action."},
        )

    def test_model_version_requires_contributor(self):
        """
        A user without contributor access to the model cannot create a version
        """
        self.client.force_login(self.user1)
        url = reverse("api:model-version-create", kwargs={"pk": str(self.model2.id)})
        payload = {'hash': '05a5cdbaf05d3b6cc51fcb173d0057c0', 'size': 8}

        with self.assertNumQueries(6):
            response = self.client.post(url, payload)

        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
        self.assertDictEqual(
            response.json(),
            {"detail": "You need a Contributor access to the model to create a new version."},
        )

    @patch('arkindex.project.aws.s3.meta.client.generate_presigned_url')
    def test_model_version_creation_no_tag(self, s3_presigned_url_mock):
        """
        A version created without a tag is returned with a null tag
        """
        self.client.force_login(self.user1)
        s3_presigned_url_mock.return_value = 'http://s3/upload_put_url'
        payload = {
            'hash': '05a5cdbaf05d3b6cc51fcb173d0057c0',
            'size': 8
        }
        url = reverse('api:model-version-create', kwargs={"pk": str(self.model1.id)})

        with self.assertNumQueries(7):
            response = self.client.post(url, payload)

        self.assertEqual(response.status_code, status.HTTP_201_CREATED)
        data = response.json()
        self.assertIn('id', data)
        version = ModelVersion.objects.get(id=data['id'])
        expected = {
            'id': str(version.id),
            'model_id': str(self.model1.id),
            'parent': None,
            'description': '',
            'state': ModelVersionState.Created.value,
            'configuration': {},
            'tag': None,
            'size': payload['size'],
            'hash': payload['hash'],
            's3_url': s3_presigned_url_mock.return_value,
            's3_put_url': s3_presigned_url_mock.return_value
        }
        self.assertDictEqual(data, expected)

    @patch('arkindex.project.aws.s3.meta.client.generate_presigned_url')
    def test_model_version_creation_with_tag(self, s3_presigned_url_mock):
        """
        A version created with a tag is returned with that tag
        """
        self.client.force_login(self.user1)
        s3_presigned_url_mock.return_value = 'http://s3/upload_put_url'
        payload = {
            'hash': '05a5cdbaf05d3b6cc51fcb173d0057c0',
            'size': 8,
            'tag': 'TAG'
        }
        url = reverse('api:model-version-create', kwargs={"pk": str(self.model1.id)})

        with self.assertNumQueries(8):
            response = self.client.post(url, payload)

        self.assertEqual(response.status_code, status.HTTP_201_CREATED)
        data = response.json()
        self.assertIn('id', data)
        version = ModelVersion.objects.get(id=data['id'])
        expected = {
            'id': str(version.id),
            'model_id': str(self.model1.id),
            'parent': None,
            'description': '',
            'state': ModelVersionState.Created.value,
            'configuration': {},
            'tag': payload['tag'],
            'size': payload['size'],
            'hash': payload['hash'],
            's3_url': s3_presigned_url_mock.return_value,
            's3_put_url': s3_presigned_url_mock.return_value
        }
        self.assertDictEqual(data, expected)

    def test_create_model_version_unique(self):
        """
        Creating a version with an already used (model, tag) pair returns a 400
        """
        self.client.force_login(self.user1)
        ModelVersion.objects.create(
            model=self.model1,
            tag='production',
            hash='5a50cdbaf05d3b6cc51fcb173d0057c0',
            size=16,
        )
        url = reverse('api:model-version-create', kwargs={"pk": str(self.model1.id)})
        payload = {
            'hash': '05a5cdbaf05d3b6cc51fcb173d0057c0',
            'size': 8,
            'tag': 'production'
        }

        with self.assertNumQueries(7):
            response = self.client.post(url, payload)

        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
        self.assertDictEqual(
            response.json(),
            {"non_field_errors": ["A version for this model with this tag already exists."]},
        )
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment