From 7e357eeb11f53f2b35b6bf7689d1a3557a3b988b Mon Sep 17 00:00:00 2001 From: Yoann Schneider <yschneider@teklia.com> Date: Thu, 24 Mar 2022 08:57:17 +0000 Subject: [PATCH] Create model version endpoint --- arkindex/project/api_v1.py | 4 + arkindex/project/config.py | 1 + arkindex/project/settings.py | 5 + .../tests/config_samples/defaults.yaml | 1 + .../tests/config_samples/override.yaml | 1 + arkindex/training/__init__.py | 0 arkindex/training/api.py | 41 +++++ .../migrations/0002_alter_modelversion_tag.py | 18 ++ arkindex/training/models.py | 12 +- arkindex/training/serializers.py | 65 +++++++ arkindex/training/tests/__init__.py | 0 .../training/tests/test_model_versions.py | 160 ++++++++++++++++++ 12 files changed, 306 insertions(+), 2 deletions(-) create mode 100644 arkindex/training/__init__.py create mode 100644 arkindex/training/api.py create mode 100644 arkindex/training/migrations/0002_alter_modelversion_tag.py create mode 100644 arkindex/training/serializers.py create mode 100644 arkindex/training/tests/__init__.py create mode 100644 arkindex/training/tests/test_model_versions.py diff --git a/arkindex/project/api_v1.py b/arkindex/project/api_v1.py index d7e345a2c9..b844cf4053 100644 --- a/arkindex/project/api_v1.py +++ b/arkindex/project/api_v1.py @@ -88,6 +88,7 @@ from arkindex.documents.api.ml import ( from arkindex.documents.api.search import CorpusSearch from arkindex.images.api import IIIFInformationCreate, IIIFURLCreate, ImageCreate, ImageElements, ImageRetrieve from arkindex.project.openapi import OpenApiSchemaView +from arkindex.training.api import ModelVersionsCreate from arkindex.users.api import ( CredentialsList, CredentialsRetrieve, @@ -235,6 +236,9 @@ api = [ path('process/<uuid:pk>/template/', CreateProcessTemplate.as_view(), name='create-process-template'), path('process/<uuid:pk>/apply/', ApplyProcessTemplate.as_view(), name='apply-process-template'), + # ML models training + path('model/<uuid:pk>/versions/', 
ModelVersionsCreate.as_view(), name='model-version-create'), + # Image management path('image/', ImageCreate.as_view(), name='image-create'), path('image/iiif/url/', IIIFURLCreate.as_view(), name='iiif-url-create'), diff --git a/arkindex/project/config.py b/arkindex/project/config.py index e7a567193c..857e0afc7d 100644 --- a/arkindex/project/config.py +++ b/arkindex/project/config.py @@ -195,6 +195,7 @@ def get_settings_parser(base_dir): s3_parser.add_option('thumbnails_bucket', type=str, default='thumbnails') s3_parser.add_option('staging_bucket', type=str, default='staging') s3_parser.add_option('export_bucket', type=str, default='export') + s3_parser.add_option('training_bucket', type=str, default='training') s3_parser.add_option('ponos_logs_bucket', type=str, default='ponos-logs') s3_parser.add_option('ponos_artifacts_bucket', type=str, default='ponos-artifacts') diff --git a/arkindex/project/settings.py b/arkindex/project/settings.py index f476cc5fc6..6b65c56673 100644 --- a/arkindex/project/settings.py +++ b/arkindex/project/settings.py @@ -271,6 +271,10 @@ SPECTACULAR_SETTINGS = { 'name': 'ml', 'description': 'Machine Learning tools and results', }, + { + 'name': 'training', + 'description': 'Machine Learning models training', + }, {'name': 'oauth'}, {'name': 'ponos'}, {'name': 'repos'}, @@ -488,6 +492,7 @@ PONOS_S3_ARTIFACTS_BUCKET = conf['s3']['ponos_artifacts_bucket'] AWS_THUMBNAIL_BUCKET = conf['s3']['thumbnails_bucket'] AWS_STAGING_BUCKET = conf['s3']['staging_bucket'] AWS_EXPORT_BUCKET = conf['s3']['export_bucket'] +AWS_TRAINING_BUCKET = conf['s3']['training_bucket'] # Ponos integration _ponos_env = { diff --git a/arkindex/project/tests/config_samples/defaults.yaml b/arkindex/project/tests/config_samples/defaults.yaml index eabfd74749..a6d0507367 100644 --- a/arkindex/project/tests/config_samples/defaults.yaml +++ b/arkindex/project/tests/config_samples/defaults.yaml @@ -78,6 +78,7 @@ s3: secret_access_key: null staging_bucket: staging 
thumbnails_bucket: thumbnails + training_bucket: training secret_key: jf0w^y&ml(caax8f&a1mub)(js9(l5mhbbhosz3gi+m01ex+lo sentry: dsn: null diff --git a/arkindex/project/tests/config_samples/override.yaml b/arkindex/project/tests/config_samples/override.yaml index 884d8c05de..f4253c6791 100644 --- a/arkindex/project/tests/config_samples/override.yaml +++ b/arkindex/project/tests/config_samples/override.yaml @@ -93,6 +93,7 @@ s3: secret_access_key: hunter2 staging_bucket: dropboxbutworse thumbnails_bucket: toenails + training_bucket: nachoneko secret_key: abcdef sentry: dsn: https://nowhere diff --git a/arkindex/training/__init__.py b/arkindex/training/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/arkindex/training/api.py b/arkindex/training/api.py new file mode 100644 index 0000000000..11cef28a67 --- /dev/null +++ b/arkindex/training/api.py @@ -0,0 +1,41 @@ +from drf_spectacular.utils import extend_schema, extend_schema_view +from rest_framework.exceptions import PermissionDenied +from rest_framework.generics import CreateAPIView + +from arkindex.project.mixins import TrainingModelMixin +from arkindex.project.permissions import IsVerified +from arkindex.training.models import Model +from arkindex.training.serializers import ModelVersionCreateSerializer + + +@extend_schema(tags=['training']) +@extend_schema_view( + post=extend_schema( + operation_id='CreateModelVersion', + description=( + 'Create a new version for a Machine Learning model.\n\n' + 'Requires a **contributor** access to the model.' 
+ ), + responses={ + 201: ModelVersionCreateSerializer, 403: None + }, + ), +) +class ModelVersionsCreate(TrainingModelMixin, CreateAPIView): + permission_classes = (IsVerified, ) + serializer_class = ModelVersionCreateSerializer + queryset = Model.objects.none() + + def get_queryset(self): + return Model.objects.filter(id=self.kwargs['pk']) + + def get_serializer_context(self): + context = super().get_serializer_context() + context['model'] = self.get_object() + return context + + def check_object_permissions(self, request, model): + super().check_object_permissions(request, model) + + if not self.has_write_access(model): + raise PermissionDenied(detail='You need a Contributor access to the model to create a new version.') diff --git a/arkindex/training/migrations/0002_alter_modelversion_tag.py b/arkindex/training/migrations/0002_alter_modelversion_tag.py new file mode 100644 index 0000000000..156b5506b4 --- /dev/null +++ b/arkindex/training/migrations/0002_alter_modelversion_tag.py @@ -0,0 +1,18 @@ +# Generated by Django 4.0.2 on 2022-03-23 14:06 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('training', '0001_initial'), + ] + + operations = [ + migrations.AlterField( + model_name='modelversion', + name='tag', + field=models.CharField(blank=True, default=None, max_length=50, null=True), + ), + ] diff --git a/arkindex/training/models.py b/arkindex/training/models.py index e51be834ad..416cb49acc 100644 --- a/arkindex/training/models.py +++ b/arkindex/training/models.py @@ -1,7 +1,10 @@ +from django.conf import settings from django.contrib.contenttypes.fields import GenericRelation from django.db import models +from django.utils.functional import cached_property from enumfields import Enum, EnumField +from arkindex.project.aws import S3FileMixin from arkindex.project.fields import MD5HashField from arkindex.project.models import IndexableModel @@ -36,7 +39,7 @@ class ModelVersionState(Enum): Error = 
'error' -class ModelVersion(IndexableModel): +class ModelVersion(S3FileMixin, IndexableModel): """ A specific Model version """ @@ -46,7 +49,7 @@ class ModelVersion(IndexableModel): description = models.TextField(default="") - tag = models.CharField(null=True, max_length=50, blank=True) + tag = models.CharField(null=True, max_length=50, blank=True, default=None) state = EnumField(ModelVersionState, default=ModelVersionState.Created) @@ -63,3 +66,8 @@ class ModelVersion(IndexableModel): unique_together = ( ('model', 'tag'), ) + s3_bucket = settings.AWS_TRAINING_BUCKET + + @cached_property + def s3_key(self): + return f'{self.id}.zst' diff --git a/arkindex/training/serializers.py b/arkindex/training/serializers.py new file mode 100644 index 0000000000..2b48c6bcb3 --- /dev/null +++ b/arkindex/training/serializers.py @@ -0,0 +1,65 @@ + +import re + +from drf_spectacular.utils import extend_schema_field +from rest_framework import serializers +from rest_framework.validators import UniqueTogetherValidator + +from arkindex.project.mixins import TrainingModelMixin +from arkindex.project.serializer_fields import EnumField +from arkindex.training.models import ModelVersion, ModelVersionState + + +def _model_from_context(serializer_field): + return serializer_field.context.get('model') + + +_model_from_context.requires_context = True + + +class ModelVersionSerializer(TrainingModelMixin, serializers.ModelSerializer): + model = serializers.HiddenField(default=_model_from_context) + tag = serializers.CharField(allow_blank=True, allow_null=True, max_length=50, required=False, default=None) + parent = serializers.PrimaryKeyRelatedField(queryset=ModelVersion.objects.none(), default=None) + state = EnumField(ModelVersionState, read_only=True) + s3_url = serializers.SerializerMethodField() + + class Meta: + model = ModelVersion + fields = ('id', 'model', 'model_id', 'parent', 'description', 'tag', 'hash', 'state', 'size', 'configuration', 's3_url') + read_only_fields = ('id', 
'model_id', 'parent', 'state', 'hash', 'size', 's3_url') + + validators = [ + UniqueTogetherValidator( + queryset=ModelVersion.objects.filter(tag__isnull=False), + fields=['model', 'tag'], + message='A version for this model with this tag already exists.' + ) + ] + + @extend_schema_field(serializers.CharField(allow_null=True)) + def get_s3_url(self, obj): + return obj.s3_url + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # When used for CreateModelVersion, the API adds the provided model to the context + if self.context.get('model') is not None: + self.fields['parent'].queryset = ModelVersion.objects.filter(model_id=self.context['model'].id) + + +class ModelVersionCreateSerializer(ModelVersionSerializer): + """ + Create a new Model Version by providing the hash and the size of the archive + """ + hash = serializers.RegexField(re.compile(r'[0-9A-Fa-f]{32}'), min_length=32, max_length=32) + size = serializers.IntegerField(min_value=0) + s3_put_url = serializers.SerializerMethodField() + + class Meta(ModelVersionSerializer.Meta): + fields = ModelVersionSerializer.Meta.fields + ('s3_put_url',) + read_only_fields = ModelVersionSerializer.Meta.read_only_fields + ('s3_put_url',) + + @extend_schema_field(serializers.CharField(allow_null=True)) + def get_s3_put_url(self, obj): + return obj.s3_put_url diff --git a/arkindex/training/tests/__init__.py b/arkindex/training/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/arkindex/training/tests/test_model_versions.py b/arkindex/training/tests/test_model_versions.py new file mode 100644 index 0000000000..c8ebe414e0 --- /dev/null +++ b/arkindex/training/tests/test_model_versions.py @@ -0,0 +1,160 @@ + + +from unittest.mock import patch + +from django.urls import reverse +from rest_framework import status + +from arkindex.project.tests import FixtureAPITestCase +from arkindex.training.models import Model, ModelVersion, ModelVersionState +from arkindex.users.models 
import Group, Right, User + + +class TestModelVersion(FixtureAPITestCase): + """ + Test model versions + """ + + @classmethod + def setUpTestData(cls): + r"""We use a simple rights configuration for those tests + + User1 User2 User3 + | / / + | 100 / + | | / + 100 Group1 10 + | / \ / + | 50 100 / + | / \ / + Model1 Model2 + """ + super().setUpTestData() + + # Create three models + cls.model1 = Model.objects.create(name="First Model") + cls.model2 = Model.objects.create(name="Second Model") + cls.model3 = Model.objects.create(name="Third Model") + + # Create three users with different access rights on each model + cls.user1 = User.objects.create(email='user1@test.test', display_name='User 1', verified_email=True) + cls.user2 = User.objects.create(email='user2@test.test', display_name='User 2', verified_email=True) + cls.user3 = User.objects.create(email='user3@test.test', display_name='User 3', verified_email=True) + + # Create the group + cls.group1 = Group.objects.create(name="Group1") + + # Access rights + Right.objects.bulk_create([ + Right(user=cls.user1, content_object=cls.model1, level=100), + Right(user=cls.user2, content_object=cls.group1, level=100), + Right(user=cls.user3, content_object=cls.model2, level=10), + Right(group=cls.group1, content_object=cls.model1, level=50), + Right(group=cls.group1, content_object=cls.model2, level=100), + ]) + + def test_create_model_version_requires_verified(self): + user = User.objects.create(email='not_verified@mail.com', display_name='Not Verified', verified_email=False) + self.client.force_login(user) + with self.assertNumQueries(2): + response = self.client.post(reverse("api:model-version-create", kwargs={"pk": str(self.model2.id)}), {'hash': '05a5cdbaf05d3b6cc51fcb173d0057c0', 'size': 8}) + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + self.assertDictEqual(response.json(), {"detail": "You do not have permission to perform this action."}) + + def test_model_version_requires_contributor(self): + 
""" + Can't create model version as guest + """ + self.client.force_login(self.user1) + with self.assertNumQueries(6): + response = self.client.post(reverse("api:model-version-create", kwargs={"pk": str(self.model2.id)}), {'hash': '05a5cdbaf05d3b6cc51fcb173d0057c0', 'size': 8}) + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + self.assertDictEqual(response.json(), {"detail": "You need a Contributor access to the model to create a new version."}) + + @patch('arkindex.project.aws.s3.meta.client.generate_presigned_url') + def test_model_version_creation_no_tag(self, s3_presigned_url_mock): + """ + Creates a new model version without setting a tag + """ + self.client.force_login(self.user1) + s3_presigned_url_mock.return_value = 'http://s3/upload_put_url' + request = { + 'hash': '05a5cdbaf05d3b6cc51fcb173d0057c0', + 'size': 8 + } + with self.assertNumQueries(7): + response = self.client.post(reverse('api:model-version-create', kwargs={"pk": str(self.model1.id)}), request) + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + data = response.json() + self.assertIn('id', data) + df = ModelVersion.objects.get(id=data['id']) + + self.assertDictEqual( + data, + { + 'id': str(df.id), + 'model_id': str(self.model1.id), + 'parent': None, + 'description': '', + 'state': ModelVersionState.Created.value, + 'configuration': {}, + 'tag': None, + 'size': request['size'], + 'hash': request['hash'], + 's3_url': s3_presigned_url_mock.return_value, + 's3_put_url': s3_presigned_url_mock.return_value + } + ) + + @patch('arkindex.project.aws.s3.meta.client.generate_presigned_url') + def test_model_version_creation_with_tag(self, s3_presigned_url_mock): + """ + Creates a new model version with a tag + """ + self.client.force_login(self.user1) + s3_presigned_url_mock.return_value = 'http://s3/upload_put_url' + request = { + 'hash': '05a5cdbaf05d3b6cc51fcb173d0057c0', + 'size': 8, + 'tag': 'TAG' + } + with self.assertNumQueries(8): + response = 
self.client.post(reverse('api:model-version-create', kwargs={"pk": str(self.model1.id)}), request) + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + data = response.json() + self.assertIn('id', data) + df = ModelVersion.objects.get(id=data['id']) + + self.assertDictEqual( + data, + { + 'id': str(df.id), + 'model_id': str(self.model1.id), + 'parent': None, + 'description': '', + 'state': ModelVersionState.Created.value, + 'configuration': {}, + 'tag': request['tag'], + 'size': request['size'], + 'hash': request['hash'], + 's3_url': s3_presigned_url_mock.return_value, + 's3_put_url': s3_presigned_url_mock.return_value + } + ) + + def test_create_model_version_unique(self): + """ + Raises 400 when creating a model version that already exists, + same model_id and tag + """ + self.client.force_login(self.user1) + ModelVersion.objects.create(model=self.model1, tag='production', hash='5a50cdbaf05d3b6cc51fcb173d0057c0', size=16) + with self.assertNumQueries(7): + response = self.client.post(reverse('api:model-version-create', kwargs={"pk": str(self.model1.id)}), { + 'hash': '05a5cdbaf05d3b6cc51fcb173d0057c0', + 'size': 8, + 'tag': 'production' + }) + + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertDictEqual(response.json(), {"non_field_errors": ["A version for this model with this tag already exists."]}) -- GitLab