From 7e357eeb11f53f2b35b6bf7689d1a3557a3b988b Mon Sep 17 00:00:00 2001
From: Yoann Schneider <yschneider@teklia.com>
Date: Thu, 24 Mar 2022 08:57:17 +0000
Subject: [PATCH] Create model version endpoint

---
 arkindex/project/api_v1.py                    |   4 +
 arkindex/project/config.py                    |   1 +
 arkindex/project/settings.py                  |   5 +
 .../tests/config_samples/defaults.yaml        |   1 +
 .../tests/config_samples/override.yaml        |   1 +
 arkindex/training/__init__.py                 |   0
 arkindex/training/api.py                      |  41 +++++
 .../migrations/0002_alter_modelversion_tag.py |  18 ++
 arkindex/training/models.py                   |  12 +-
 arkindex/training/serializers.py              |  65 +++++++
 arkindex/training/tests/__init__.py           |   0
 .../training/tests/test_model_versions.py     | 160 ++++++++++++++++++
 12 files changed, 306 insertions(+), 2 deletions(-)
 create mode 100644 arkindex/training/__init__.py
 create mode 100644 arkindex/training/api.py
 create mode 100644 arkindex/training/migrations/0002_alter_modelversion_tag.py
 create mode 100644 arkindex/training/serializers.py
 create mode 100644 arkindex/training/tests/__init__.py
 create mode 100644 arkindex/training/tests/test_model_versions.py

diff --git a/arkindex/project/api_v1.py b/arkindex/project/api_v1.py
index d7e345a2c9..b844cf4053 100644
--- a/arkindex/project/api_v1.py
+++ b/arkindex/project/api_v1.py
@@ -88,6 +88,7 @@ from arkindex.documents.api.ml import (
 from arkindex.documents.api.search import CorpusSearch
 from arkindex.images.api import IIIFInformationCreate, IIIFURLCreate, ImageCreate, ImageElements, ImageRetrieve
 from arkindex.project.openapi import OpenApiSchemaView
+from arkindex.training.api import ModelVersionsCreate
 from arkindex.users.api import (
     CredentialsList,
     CredentialsRetrieve,
@@ -235,6 +236,9 @@ api = [
     path('process/<uuid:pk>/template/', CreateProcessTemplate.as_view(), name='create-process-template'),
     path('process/<uuid:pk>/apply/', ApplyProcessTemplate.as_view(), name='apply-process-template'),
 
+    # ML models training
+    path('model/<uuid:pk>/versions/', ModelVersionsCreate.as_view(), name='model-version-create'),
+
     # Image management
     path('image/', ImageCreate.as_view(), name='image-create'),
     path('image/iiif/url/', IIIFURLCreate.as_view(), name='iiif-url-create'),
diff --git a/arkindex/project/config.py b/arkindex/project/config.py
index e7a567193c..857e0afc7d 100644
--- a/arkindex/project/config.py
+++ b/arkindex/project/config.py
@@ -195,6 +195,7 @@ def get_settings_parser(base_dir):
     s3_parser.add_option('thumbnails_bucket', type=str, default='thumbnails')
     s3_parser.add_option('staging_bucket', type=str, default='staging')
     s3_parser.add_option('export_bucket', type=str, default='export')
+    s3_parser.add_option('training_bucket', type=str, default='training')
     s3_parser.add_option('ponos_logs_bucket', type=str, default='ponos-logs')
     s3_parser.add_option('ponos_artifacts_bucket', type=str, default='ponos-artifacts')
 
diff --git a/arkindex/project/settings.py b/arkindex/project/settings.py
index f476cc5fc6..6b65c56673 100644
--- a/arkindex/project/settings.py
+++ b/arkindex/project/settings.py
@@ -271,6 +271,10 @@ SPECTACULAR_SETTINGS = {
             'name': 'ml',
             'description': 'Machine Learning tools and results',
         },
+        {
+            'name': 'training',
+            'description': 'Machine Learning models training',
+        },
         {'name': 'oauth'},
         {'name': 'ponos'},
         {'name': 'repos'},
@@ -488,6 +492,7 @@ PONOS_S3_ARTIFACTS_BUCKET = conf['s3']['ponos_artifacts_bucket']
 AWS_THUMBNAIL_BUCKET = conf['s3']['thumbnails_bucket']
 AWS_STAGING_BUCKET = conf['s3']['staging_bucket']
 AWS_EXPORT_BUCKET = conf['s3']['export_bucket']
+AWS_TRAINING_BUCKET = conf['s3']['training_bucket']
 
 # Ponos integration
 _ponos_env = {
diff --git a/arkindex/project/tests/config_samples/defaults.yaml b/arkindex/project/tests/config_samples/defaults.yaml
index eabfd74749..a6d0507367 100644
--- a/arkindex/project/tests/config_samples/defaults.yaml
+++ b/arkindex/project/tests/config_samples/defaults.yaml
@@ -78,6 +78,7 @@ s3:
   secret_access_key: null
   staging_bucket: staging
   thumbnails_bucket: thumbnails
+  training_bucket: training
 secret_key: jf0w^y&ml(caax8f&a1mub)(js9(l5mhbbhosz3gi+m01ex+lo
 sentry:
   dsn: null
diff --git a/arkindex/project/tests/config_samples/override.yaml b/arkindex/project/tests/config_samples/override.yaml
index 884d8c05de..f4253c6791 100644
--- a/arkindex/project/tests/config_samples/override.yaml
+++ b/arkindex/project/tests/config_samples/override.yaml
@@ -93,6 +93,7 @@ s3:
   secret_access_key: hunter2
   staging_bucket: dropboxbutworse
   thumbnails_bucket: toenails
+  training_bucket: nachoneko
 secret_key: abcdef
 sentry:
   dsn: https://nowhere
diff --git a/arkindex/training/__init__.py b/arkindex/training/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/arkindex/training/api.py b/arkindex/training/api.py
new file mode 100644
index 0000000000..11cef28a67
--- /dev/null
+++ b/arkindex/training/api.py
@@ -0,0 +1,41 @@
+from drf_spectacular.utils import extend_schema, extend_schema_view
+from rest_framework.exceptions import PermissionDenied
+from rest_framework.generics import CreateAPIView
+
+from arkindex.project.mixins import TrainingModelMixin
+from arkindex.project.permissions import IsVerified
+from arkindex.training.models import Model
+from arkindex.training.serializers import ModelVersionCreateSerializer
+
+
+@extend_schema(tags=['training'])
+@extend_schema_view(
+    post=extend_schema(
+        operation_id='CreateModelVersion',
+        description=(
+            'Create a new version for a Machine Learning model.\n\n'
+            'Requires a **contributor** access to the model.'
+        ),
+        responses={
+            200: ModelVersionCreateSerializer, 403: None
+        },
+    ),
+)
+class ModelVersionsCreate(TrainingModelMixin, CreateAPIView):
+    permission_classes = (IsVerified, )
+    serializer_class = ModelVersionCreateSerializer
+    queryset = Model.objects.none()
+
+    def get_queryset(self):
+        return Model.objects.filter(id=self.kwargs['pk'])
+
+    def get_serializer_context(self):
+        context = super().get_serializer_context()
+        context['model'] = self.get_object()
+        return context
+
+    def check_object_permissions(self, request, model):
+        super().check_object_permissions(request, model)
+
+        if not self.has_write_access(model):
+            raise PermissionDenied(detail='You need a Contributor access to the model to create a new version.')
diff --git a/arkindex/training/migrations/0002_alter_modelversion_tag.py b/arkindex/training/migrations/0002_alter_modelversion_tag.py
new file mode 100644
index 0000000000..156b5506b4
--- /dev/null
+++ b/arkindex/training/migrations/0002_alter_modelversion_tag.py
@@ -0,0 +1,18 @@
+# Generated by Django 4.0.2 on 2022-03-23 14:06
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+
+    dependencies = [
+        ('training', '0001_initial'),
+    ]
+
+    operations = [
+        migrations.AlterField(
+            model_name='modelversion',
+            name='tag',
+            field=models.CharField(blank=True, default=None, max_length=50, null=True),
+        ),
+    ]
diff --git a/arkindex/training/models.py b/arkindex/training/models.py
index e51be834ad..416cb49acc 100644
--- a/arkindex/training/models.py
+++ b/arkindex/training/models.py
@@ -1,7 +1,10 @@
+from django.conf import settings
 from django.contrib.contenttypes.fields import GenericRelation
 from django.db import models
+from django.utils.functional import cached_property
 from enumfields import Enum, EnumField
 
+from arkindex.project.aws import S3FileMixin
 from arkindex.project.fields import MD5HashField
 from arkindex.project.models import IndexableModel
 
@@ -36,7 +39,7 @@ class ModelVersionState(Enum):
     Error = 'error'
 
 
-class ModelVersion(IndexableModel):
+class ModelVersion(S3FileMixin, IndexableModel):
     """
     A specific Model version
     """
@@ -46,7 +49,7 @@ class ModelVersion(IndexableModel):
 
     description = models.TextField(default="")
 
-    tag = models.CharField(null=True, max_length=50, blank=True)
+    tag = models.CharField(null=True, max_length=50, blank=True, default=None)
 
     state = EnumField(ModelVersionState, default=ModelVersionState.Created)
 
@@ -63,3 +66,8 @@ class ModelVersion(IndexableModel):
         unique_together = (
             ('model', 'tag'),
         )
+    s3_bucket = settings.AWS_TRAINING_BUCKET
+
+    @cached_property
+    def s3_key(self):
+        return f'{self.id}.zst'
diff --git a/arkindex/training/serializers.py b/arkindex/training/serializers.py
new file mode 100644
index 0000000000..2b48c6bcb3
--- /dev/null
+++ b/arkindex/training/serializers.py
@@ -0,0 +1,65 @@
+
+import re
+
+from drf_spectacular.utils import extend_schema_field
+from rest_framework import serializers
+from rest_framework.validators import UniqueTogetherValidator
+
+from arkindex.project.mixins import TrainingModelMixin
+from arkindex.project.serializer_fields import EnumField
+from arkindex.training.models import ModelVersion, ModelVersionState
+
+
+def _model_from_context(serializer_field):
+    return serializer_field.context.get('model')
+
+
+_model_from_context.requires_context = True
+
+
+class ModelVersionSerializer(TrainingModelMixin, serializers.ModelSerializer):
+    model = serializers.HiddenField(default=_model_from_context)
+    tag = serializers.CharField(allow_blank=True, allow_null=True, max_length=50, required=False, default=None)
+    parent = serializers.PrimaryKeyRelatedField(queryset=ModelVersion.objects.none(), default=None)
+    state = EnumField(ModelVersionState, read_only=True)
+    s3_url = serializers.SerializerMethodField()
+
+    class Meta:
+        model = ModelVersion
+        fields = ('id', 'model', 'model_id', 'parent', 'description', 'tag', 'hash', 'state', 'size', 'configuration', 's3_url')
+        read_only_fields = ('id', 'model_id', 'parent', 'state', 'hash', 'size', 's3_url')
+
+        validators = [
+            UniqueTogetherValidator(
+                queryset=ModelVersion.objects.filter(tag__isnull=False),
+                fields=['model', 'tag'],
+                message='A version for this model with this tag already exists.'
+            )
+        ]
+
+    @extend_schema_field(serializers.CharField(allow_null=True))
+    def get_s3_url(self, obj):
+        return obj.s3_url
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # When used for CreateModelVersion, the API adds the provided model to the context
+        if self.context.get('model') is not None:
+            self.fields['parent'].queryset = ModelVersion.objects.filter(model_id=self.context['model'].id)
+
+
+class ModelVersionCreateSerializer(ModelVersionSerializer):
+    """
+    Create a new Model Version by providing the hash and the size of the archive
+    """
+    hash = serializers.RegexField(re.compile(r'[0-9A-Fa-f]{32}'), min_length=32, max_length=32)
+    size = serializers.IntegerField(min_value=0)
+    s3_put_url = serializers.SerializerMethodField()
+
+    class Meta(ModelVersionSerializer.Meta):
+        fields = ModelVersionSerializer.Meta.fields + ('s3_put_url',)
+        read_only_fields = ModelVersionSerializer.Meta.read_only_fields + ('s3_put_url',)
+
+    @extend_schema_field(serializers.CharField(allow_null=True))
+    def get_s3_put_url(self, obj):
+        return obj.s3_put_url
diff --git a/arkindex/training/tests/__init__.py b/arkindex/training/tests/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/arkindex/training/tests/test_model_versions.py b/arkindex/training/tests/test_model_versions.py
new file mode 100644
index 0000000000..c8ebe414e0
--- /dev/null
+++ b/arkindex/training/tests/test_model_versions.py
@@ -0,0 +1,160 @@
+
+
+from unittest.mock import patch
+
+from django.urls import reverse
+from rest_framework import status
+
+from arkindex.project.tests import FixtureAPITestCase
+from arkindex.training.models import Model, ModelVersion, ModelVersionState
+from arkindex.users.models import Group, Right, User
+
+
+class TestModelVersion(FixtureAPITestCase):
+    """
+    Test model versions
+    """
+
+    @classmethod
+    def setUpTestData(cls):
+        r"""We use a simple rights configuration for those tests
+
+            User1     User2       User3
+              |        /           /
+              |      100          /
+              |       |          /
+             100   Group1      10
+              |     / \        /
+              |   50   100    /
+              |  /       \   /
+            Model1      Model2
+        """
+        super().setUpTestData()
+
+        # Create there models
+        cls.model1 = Model.objects.create(name="First Model")
+        cls.model2 = Model.objects.create(name="Second Model")
+        cls.model3 = Model.objects.create(name="Third Model")
+
+        # Create three users with different access rights on each model
+        cls.user1 = User.objects.create(email='user1@test.test', display_name='User 1', verified_email=True)
+        cls.user2 = User.objects.create(email='user2@test.test', display_name='User 2', verified_email=True)
+        cls.user3 = User.objects.create(email='user3@test.test', display_name='User 3', verified_email=True)
+
+        # Create the group
+        cls.group1 = Group.objects.create(name="Group1")
+
+        # Access rights
+        Right.objects.bulk_create([
+            Right(user=cls.user1, content_object=cls.model1, level=100),
+            Right(user=cls.user2, content_object=cls.group1, level=100),
+            Right(user=cls.user3, content_object=cls.model2, level=10),
+            Right(group=cls.group1, content_object=cls.model1, level=50),
+            Right(group=cls.group1, content_object=cls.model2, level=100),
+        ])
+
+    def test_create_model_version_requires_verified(self):
+        user = User.objects.create(email='not_verified@mail.com', display_name='Not Verified', verified_email=False)
+        self.client.force_login(user)
+        with self.assertNumQueries(2):
+            response = self.client.post(reverse("api:model-version-create", kwargs={"pk": str(self.model2.id)}), {'hash': '05a5cdbaf05d3b6cc51fcb173d0057c0', 'size': 8})
+        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+        self.assertDictEqual(response.json(), {"detail": "You do not have permission to perform this action."})
+
+    def test_model_version_requires_contributor(self):
+        """
+        Can't create model version as guest
+        """
+        self.client.force_login(self.user1)
+        with self.assertNumQueries(6):
+            response = self.client.post(reverse("api:model-version-create", kwargs={"pk": str(self.model2.id)}), {'hash': '05a5cdbaf05d3b6cc51fcb173d0057c0', 'size': 8})
+        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+        self.assertDictEqual(response.json(), {"detail": "You need a Contributor access to the model to create a new version."})
+
+    @patch('arkindex.project.aws.s3.meta.client.generate_presigned_url')
+    def test_model_version_creation_no_tag(self, s3_presigned_url_mock):
+        """
+        Creates a new model version without setting a tag
+        """
+        self.client.force_login(self.user1)
+        s3_presigned_url_mock.return_value = 'http://s3/upload_put_url'
+        request = {
+            'hash': '05a5cdbaf05d3b6cc51fcb173d0057c0',
+            'size': 8
+        }
+        with self.assertNumQueries(7):
+            response = self.client.post(reverse('api:model-version-create', kwargs={"pk": str(self.model1.id)}), request)
+        self.assertEqual(response.status_code, status.HTTP_201_CREATED)
+        data = response.json()
+        self.assertIn('id', data)
+        df = ModelVersion.objects.get(id=data['id'])
+
+        self.assertDictEqual(
+            data,
+            {
+                'id': str(df.id),
+                'model_id': str(self.model1.id),
+                'parent': None,
+                'description': '',
+                'state': ModelVersionState.Created.value,
+                'configuration': {},
+                'tag': None,
+                'size': request['size'],
+                'hash': request['hash'],
+                's3_url': s3_presigned_url_mock.return_value,
+                's3_put_url': s3_presigned_url_mock.return_value
+            }
+        )
+
+    @patch('arkindex.project.aws.s3.meta.client.generate_presigned_url')
+    def test_model_version_creation_with_tag(self, s3_presigned_url_mock):
+        """
+        Creates a new model version with a tag
+        """
+        self.client.force_login(self.user1)
+        s3_presigned_url_mock.return_value = 'http://s3/upload_put_url'
+        request = {
+            'hash': '05a5cdbaf05d3b6cc51fcb173d0057c0',
+            'size': 8,
+            'tag': 'TAG'
+        }
+        with self.assertNumQueries(8):
+            response = self.client.post(reverse('api:model-version-create', kwargs={"pk": str(self.model1.id)}), request)
+        self.assertEqual(response.status_code, status.HTTP_201_CREATED)
+        data = response.json()
+        self.assertIn('id', data)
+        df = ModelVersion.objects.get(id=data['id'])
+
+        self.assertDictEqual(
+            data,
+            {
+                'id': str(df.id),
+                'model_id': str(self.model1.id),
+                'parent': None,
+                'description': '',
+                'state': ModelVersionState.Created.value,
+                'configuration': {},
+                'tag': request['tag'],
+                'size': request['size'],
+                'hash': request['hash'],
+                's3_url': s3_presigned_url_mock.return_value,
+                's3_put_url': s3_presigned_url_mock.return_value
+            }
+        )
+
+    def test_create_model_version_unique(self):
+        """
+        Raises 400 when creating a model version that already exists,
+        same model_id and tag
+        """
+        self.client.force_login(self.user1)
+        ModelVersion.objects.create(model=self.model1, tag='production', hash='5a50cdbaf05d3b6cc51fcb173d0057c0', size=16)
+        with self.assertNumQueries(7):
+            response = self.client.post(reverse('api:model-version-create', kwargs={"pk": str(self.model1.id)}), {
+                'hash': '05a5cdbaf05d3b6cc51fcb173d0057c0',
+                'size': 8,
+                'tag': 'production'
+            })
+
+        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+        self.assertDictEqual(response.json(), {"non_field_errors": ["A version for this model with this tag already exists."]})
-- 
GitLab