From c51e18ef522f8a9d6037809882b591743c3433d4 Mon Sep 17 00:00:00 2001
From: Erwan Rouchet <rouchet@teklia.com>
Date: Mon, 26 Oct 2020 15:30:45 +0000
Subject: [PATCH] CreateTranscriptions endpoint

---
 arkindex/documents/api/ml.py                  |  26 ++---
 arkindex/documents/serializers/ml.py          |  50 +++++++++
 .../tests/test_bulk_transcriptions.py         | 106 ++++++++++++++++++
 arkindex/project/openapi/patch.yml            |   8 --
 4 files changed, 167 insertions(+), 23 deletions(-)
 create mode 100644 arkindex/documents/tests/test_bulk_transcriptions.py

diff --git a/arkindex/documents/api/ml.py b/arkindex/documents/api/ml.py
index 68dff4af39..fd3820662c 100644
--- a/arkindex/documents/api/ml.py
+++ b/arkindex/documents/api/ml.py
@@ -3,8 +3,8 @@ from django.db import transaction
 from django.db.models import Q, Count
 from rest_framework import status
 from rest_framework.generics import (
-    GenericAPIView, ListAPIView, ListCreateAPIView,
-    CreateAPIView, UpdateAPIView, RetrieveDestroyAPIView, RetrieveUpdateDestroyAPIView
+    GenericAPIView, ListAPIView, ListCreateAPIView, CreateAPIView,
+    RetrieveDestroyAPIView, RetrieveUpdateDestroyAPIView
 )
 from rest_framework.exceptions import PermissionDenied, ValidationError
 from rest_framework.response import Response
@@ -13,9 +13,10 @@ from arkindex.documents.models import \
 from arkindex_common.ml_tool import MLToolType
 from arkindex.documents.serializers.ml import (
     ClassificationsSerializer, ClassificationCreateSerializer, ClassificationSerializer,
-    TranscriptionSerializer, TranscriptionCreateSerializer, ElementTranscriptionsBulkSerializer,
-    DataSourceStatsSerializer, ClassificationsSelectionSerializer, ClassificationMode,
-    CountMLClassSerializer, AnnotatedElementSerializer
+    ClassificationsSelectionSerializer, ClassificationMode,
+    TranscriptionSerializer, TranscriptionCreateSerializer, TranscriptionBulkSerializer,
+    ElementTranscriptionsBulkSerializer, AnnotatedElementSerializer,
+    DataSourceStatsSerializer, CountMLClassSerializer
 )
 from arkindex.images.models import Zone
 from arkindex.project.filters import SafeSearchFilter
@@ -292,21 +293,16 @@ class ElementTranscriptionsBulk(CreateAPIView):
         return annotations
 
 
-class TranscriptionBulk(DeprecatedMixin, CreateAPIView, UpdateAPIView):
+class TranscriptionBulk(CreateAPIView):
     '''
-    Create multiple transcriptions at once, all linked to the same page
-    and to the same recognizer.
+    Create multiple transcriptions at once on existing elements
     '''
-    # Force DRF to ignore PATCH
-    http_method_names = ['post', 'put', 'head', 'options', 'trace']
     openapi_overrides = {
+        'operationId': 'CreateTranscriptions',
         'tags': ['transcriptions'],
     }
-    deprecation_message = (
-        'Creating or updating transcriptions with their own zones is now deprecated. '
-        'Please use CreateElementTranscriptions to push transcriptions in bulk '
-        'attached to sub-elements.'
-    )
+    permission_classes = (IsVerified, )
+    serializer_class = TranscriptionBulkSerializer
 
 
 class CorpusMLClassList(CorpusACLMixin, ListCreateAPIView):
diff --git a/arkindex/documents/serializers/ml.py b/arkindex/documents/serializers/ml.py
index 40ef2f086a..7fa83582f1 100644
--- a/arkindex/documents/serializers/ml.py
+++ b/arkindex/documents/serializers/ml.py
@@ -369,6 +369,56 @@ class AnnotatedElementSerializer(serializers.Serializer):
     created = serializers.BooleanField(default=False)
 
 
+class TranscriptionBulkItemSerializer(serializers.Serializer):
+    # Element retrieval and checks is done in the BulkSerializer to avoid duplicate queries
+    element_id = serializers.UUIDField(
+        help_text='ID of an existing element to add the transcription to'
+    )
+    type = EnumField(TranscriptionType, help_text='Type of the transcription')
+    text = serializers.CharField()
+    score = serializers.FloatField(min_value=0, max_value=1)
+
+
+class TranscriptionBulkSerializer(serializers.Serializer):
+    worker_version = serializers.PrimaryKeyRelatedField(queryset=WorkerVersion.objects.all())
+    transcriptions = TranscriptionBulkItemSerializer(many=True)
+
+    def validate(self, data):
+        element_ids = set(transcription['element_id'] for transcription in data['transcriptions'])
+        found_ids = set(Element.objects.filter(
+            id__in=element_ids,
+            corpus__in=Corpus.objects.writable(self.context['request'].user)
+        ).values_list('id', flat=True))
+
+        missing_ids = element_ids - found_ids
+        if not missing_ids:
+            return data
+
+        # Return an error message as a list just like DRF's ListField, for easier debugging
+        raise ValidationError({'transcriptions': [
+            {"element_id": [f'Element {transcription["element_id"]} was not found or cannot be written to.']}
+            if transcription['element_id'] in missing_ids
+            else {}
+            for i, transcription in enumerate(data['transcriptions'])
+        ]})
+
+    def create(self, validated_data):
+        transcriptions = [
+            Transcription(
+                worker_version=validated_data['worker_version'],
+                element_id=transcription['element_id'],
+                type=transcription['type'],
+                text=transcription['text'],
+                score=transcription['score'],
+            )
+            for transcription in validated_data['transcriptions']
+        ]
+        Transcription.objects.bulk_create(transcriptions)
+
+        validated_data['transcriptions'] = transcriptions
+        return validated_data
+
+
 class ClassificationBulkSerializer(serializers.Serializer):
     """
     Single classification serializer for bulk insertion
diff --git a/arkindex/documents/tests/test_bulk_transcriptions.py b/arkindex/documents/tests/test_bulk_transcriptions.py
new file mode 100644
index 0000000000..12cb09e2b5
--- /dev/null
+++ b/arkindex/documents/tests/test_bulk_transcriptions.py
@@ -0,0 +1,106 @@
+from django.urls import reverse
+from rest_framework import status
+from arkindex_common.enums import TranscriptionType
+from arkindex.project.tests import FixtureAPITestCase
+from arkindex.dataimport.models import WorkerVersion
+
+
+class TestBulkTranscriptions(FixtureAPITestCase):
+
+    @classmethod
+    def setUpTestData(cls):
+        super().setUpTestData()
+        cls.worker_version = WorkerVersion.objects.get(worker__slug='reco')
+
+    def test_bulk_transcriptions_requires_login(self):
+        with self.assertNumQueries(0):
+            response = self.client.post(reverse('api:transcription-bulk'))
+        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+
+    def test_bulk_transcriptions_not_found(self):
+        self.client.force_login(self.user)
+        self.user.corpus_right.all().delete()
+        forbidden_element = self.corpus.elements.get(name='Volume 1, page 1r')
+        with self.assertNumQueries(4):
+            response = self.client.post(reverse('api:transcription-bulk'), {
+                "worker_version": str(self.worker_version.id),
+                "transcriptions": [
+                    {
+                        "element_id": str(forbidden_element.id),
+                        "type": TranscriptionType.Word.value,
+                        "text": "lol",
+                        "score": 0.4
+                    },
+                    {
+                        "element_id": "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa",
+                        "type": TranscriptionType.Word.value,
+                        "text": "lol",
+                        "score": 0.4
+                    }
+                ],
+            }, format='json')
+        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+        self.assertEqual(response.json(), {
+            'transcriptions': [
+                {'element_id': [f'Element {forbidden_element.id} was not found or cannot be written to.']},
+                {'element_id': ['Element aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa was not found or cannot be written to.']},
+            ]
+        })
+
+    def test_bulk_transcriptions(self):
+        self.client.force_login(self.user)
+
+        element1 = self.corpus.elements.get(name='Volume 2')
+        element2 = self.corpus.elements.get(name='Volume 2, page 1r')
+        self.assertFalse(element1.transcriptions.exists())
+        self.assertFalse(element2.transcriptions.exists())
+
+        with self.assertNumQueries(5):
+            response = self.client.post(reverse('api:transcription-bulk'), {
+                "worker_version": str(self.worker_version.id),
+                "transcriptions": [
+                    {
+                        "element_id": str(element1.id),
+                        "type": TranscriptionType.Word.value,
+                        "text": "Sneasel",
+                        "score": 0.54
+                    },
+                    {
+                        "element_id": str(element2.id),
+                        "type": TranscriptionType.Line.value,
+                        "text": "Charizard",
+                        "score": 0.85
+                    },
+                    {
+                        "element_id": str(element1.id),
+                        "type": TranscriptionType.Word.value,
+                        "text": "Raticate",
+                        "score": 0.12
+                    },
+                ],
+            }, format='json')
+        self.assertEqual(response.status_code, status.HTTP_201_CREATED)
+
+        self.assertCountEqual(
+            list(element1.transcriptions.values('type', 'text', 'score')),
+            [
+                {
+                    "type": TranscriptionType.Word,
+                    "text": "Sneasel",
+                    "score": 0.54
+                },
+                {
+                    "type": TranscriptionType.Word,
+                    "text": "Raticate",
+                    "score": 0.12
+                },
+            ]
+        )
+        self.assertCountEqual(
+            list(element2.transcriptions.values('type', 'text', 'score')),
+            [{
+                "type": TranscriptionType.Line,
+                "text": "Charizard",
+                "score": 0.85
+            }]
+        )
diff --git a/arkindex/project/openapi/patch.yml b/arkindex/project/openapi/patch.yml
index 811ed843d3..66cc7c59da 100644
--- a/arkindex/project/openapi/patch.yml
+++ b/arkindex/project/openapi/patch.yml
@@ -283,14 +283,6 @@ paths:
       description: Update the text of a manual transcription
     delete:
       description: Delete a manual transcription
-  /api/v1/transcription/bulk/:
-    post:
-      operationId: CreateTranscriptions
-    put:
-      operationId: UpdateTranscriptions
-      description: >-
-        Replace all existing transcriptions from a given recognizer on a page
-        with other transcriptions.
   /api/v1/metadata/{id}/:
     get:
       operationId: RetrieveMetaData
-- 
GitLab