From c97e6df30c459adbd6f5052c8ef7201a7dd2f4a2 Mon Sep 17 00:00:00 2001
From: Erwan Rouchet <rouchet@teklia.com>
Date: Thu, 11 Apr 2019 15:17:11 +0000
Subject: [PATCH] Append transcriptions in bulk

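The TranscriptionBulk endpoint used to delete every transcription
previously imported from the same source before saving a new batch, so
a second call could only replace the results of the first one. The
endpoint now accepts both methods: POST appends the new transcriptions
to the parent element without touching the existing ones, while PUT
keeps the previous behaviour and clears all transcriptions from the
same source before saving the new batch.

As a rough illustration only, a client could call the two methods as
sketched below. The base URL, element id and recognizer slug are
placeholders and authentication is omitted; the payload shape follows
TranscriptionsSerializer and the new test.

    import requests

    payload = {
        # Placeholder identifiers, not real values.
        "parent": "11111111-1111-1111-1111-111111111111",
        "recognizer": "example-ocr",
        "transcriptions": [
            {
                "type": "word",
                "polygon": [[0, 0], [100, 0], [100, 100], [0, 100], [0, 0]],
                "text": "NEKUDOTAYIM",
                "score": 0.83,
            },
        ],
    }

    # Placeholder URL for the api:transcription-bulk route.
    url = "https://arkindex.example/api/transcriptions/bulk/"

    # POST appends the transcriptions to the parent element.
    requests.post(url, json=payload)

    # PUT deletes the transcriptions previously imported from the same
    # source, then saves the new batch.
    requests.put(url, json=payload)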
---
 arkindex/documents/api/ml.py                  | 20 ++++--
 .../tests/test_transcription_create.py        | 67 +++++++++++++++++++
 2 files changed, 82 insertions(+), 5 deletions(-)

diff --git a/arkindex/documents/api/ml.py b/arkindex/documents/api/ml.py
index 2377f00a3c..e824406cf1 100644
--- a/arkindex/documents/api/ml.py
+++ b/arkindex/documents/api/ml.py
@@ -1,6 +1,6 @@
 from django.conf import settings
 from rest_framework import status
-from rest_framework.generics import CreateAPIView
+from rest_framework.generics import CreateAPIView, UpdateAPIView
 from rest_framework.exceptions import ValidationError
 from rest_framework.response import Response
 from arkindex.documents.models import Classification, DataSource, Transcription, TranscriptionType
@@ -68,7 +68,7 @@ class TranscriptionCreate(CreateAPIView):
         return Response({'id': obj.id}, status=status.HTTP_201_CREATED, headers=headers)
 
 
-class TranscriptionBulk(CreateAPIView):
+class TranscriptionBulk(CreateAPIView, UpdateAPIView):
     '''
-    Create transcriptions in bulk, all linked to the same image
+    Create or replace transcriptions in bulk, all linked to the same image
     and parent element
@@ -76,8 +76,16 @@ class TranscriptionBulk(CreateAPIView):
     serializer_class = TranscriptionsSerializer
     permission_classes = (IsVerified, )
 
+    def get_object(self):  # Skip the object lookup in UpdateAPIView.update(): there is no single instance to update
+        return
+
     def perform_create(self, serializer):
+        self.run(serializer, delete=False)  # POST appends to the existing transcriptions
+
+    def perform_update(self, serializer):
+        self.run(serializer, delete=True)  # PUT replaces the transcriptions from this source
 
+    def run(self, serializer, delete=False):
         parent = serializer.validated_data['parent']
         source = DataSource.from_ml_tool(serializer.validated_data['recognizer'])
         trpolygons = build_transcriptions(
@@ -96,8 +104,10 @@ class TranscriptionBulk(CreateAPIView):
         )
         if not trpolygons:
             return
-        # Clear all previous transcriptions from source
-        parent.transcriptions.filter(source=source).delete()
+
+        if delete:
+            # Clear all previous transcriptions from source
+            parent.transcriptions.filter(source=source).delete()
 
         transcriptions, _ = save_transcriptions(*trpolygons)
 
diff --git a/arkindex/documents/tests/test_transcription_create.py b/arkindex/documents/tests/test_transcription_create.py
index ced37c3043..9a210f42fb 100644
--- a/arkindex/documents/tests/test_transcription_create.py
+++ b/arkindex/documents/tests/test_transcription_create.py
@@ -192,3 +192,70 @@ class TestTranscriptionCreate(FixtureAPITestCase):
             ]
         })
         self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+
+    @patch('arkindex.project.serializer_fields.MLTool.get')
+    @patch('arkindex.images.importer.Indexer')
+    def test_update_bulk_transcription(self, indexer, ml_get_mock):
+        ml_get_mock.return_value.type = self.src.type
+        ml_get_mock.return_value.slug = self.src.slug
+        ml_get_mock.return_value.version = self.src.revision
+        self.src.internal = True
+        self.src.save()
+
+        data = {
+            "parent": str(self.page.id),
+            "recognizer": self.src.slug,
+            "transcriptions": [
+                {
+                    "type": "word",
+                    "polygon": [(0, 0), (100, 0), (100, 100), (0, 100), (0, 0)],
+                    "text": "NEKUDOTAYIM",
+                    "score": 0.83,
+                },
+                {
+                    "type": "line",
+                    "polygon": [(0, 0), (200, 0), (200, 200), (0, 200), (0, 0)],
+                    "text": "This is a test",
+                    "score": 0.75,
+                },
+            ]
+        }
+
+        self.client.force_login(self.user)
+        response = self.client.put(reverse('api:transcription-bulk'), format='json', data=data)
+        self.assertEqual(response.status_code, status.HTTP_200_OK)
+        new_ts = self.page.transcriptions.get(text="NEKUDOTAYIM", type=TranscriptionType.Word)
+        self.assertEqual(new_ts.zone.polygon, Polygon.from_coords(0, 0, 100, 100))
+        self.assertEqual(new_ts.score, 0.83)
+        self.assertEqual(new_ts.source, self.src)
+
+        page_ts = self.page.transcriptions.get(type=TranscriptionType.Page)
+        self.assertEqual(page_ts.text, 'This is a test')
+        self.assertEqual(page_ts.score, None)
+        self.assertEqual(page_ts.source, self.src)
+        self.assertEqual(page_ts.zone, self.page.zone)
+
+        # Indexer called
+        self.assertEqual(indexer.return_value.run_index.call_count, 2)
+
+        # Update again
+        indexer.reset_mock()
+        data['transcriptions'] = [
+            {
+                "type": "word",
+                "polygon": [(0, 0), (200, 0), (200, 200), (0, 200), (0, 0)],
+                "text": "Hi",
+                "score": 0.99,
+            },
+        ]
+
+        response = self.client.put(reverse('api:transcription-bulk'), format='json', data=data)
+        self.assertEqual(response.status_code, status.HTTP_200_OK)
+        self.assertEqual(indexer.return_value.run_index.call_count, 2)
+
+        # Previous transcriptions should be replaced by a Word and a Page transcription
+        self.assertEqual(self.page.transcriptions.count(), 2)
+        ts = self.page.transcriptions.get(type=TranscriptionType.Word)
+        self.assertEqual(ts.text, "Hi")
+        page_ts = self.page.transcriptions.get(type=TranscriptionType.Page)
+        self.assertEqual(page_ts.text, '')
-- 
GitLab