From c97e6df30c459adbd6f5052c8ef7201a7dd2f4a2 Mon Sep 17 00:00:00 2001
From: Erwan Rouchet <rouchet@teklia.com>
Date: Thu, 11 Apr 2019 15:17:11 +0000
Subject: [PATCH] Append transcriptions in bulk

---
 arkindex/documents/api/ml.py                   | 18 +++--
 .../tests/test_transcription_create.py         | 67 +++++++++++++++++++
 2 files changed, 81 insertions(+), 4 deletions(-)

diff --git a/arkindex/documents/api/ml.py b/arkindex/documents/api/ml.py
index 2377f00a3c..e824406cf1 100644
--- a/arkindex/documents/api/ml.py
+++ b/arkindex/documents/api/ml.py
@@ -1,6 +1,6 @@
 from django.conf import settings
 from rest_framework import status
-from rest_framework.generics import CreateAPIView
+from rest_framework.generics import CreateAPIView, UpdateAPIView
 from rest_framework.exceptions import ValidationError
 from rest_framework.response import Response
 from arkindex.documents.models import Classification, DataSource, Transcription, TranscriptionType
@@ -68,7 +68,7 @@ class TranscriptionCreate(CreateAPIView):
         return Response({'id': obj.id}, status=status.HTTP_201_CREATED, headers=headers)
 
 
-class TranscriptionBulk(CreateAPIView):
+class TranscriptionBulk(CreateAPIView, UpdateAPIView):
     '''
     Create transcriptions in bulk, all linked to the same image and parent element
     '''
@@ -76,8 +76,16 @@
     serializer_class = TranscriptionsSerializer
     permission_classes = (IsVerified, )
 
+    def get_object(self):  # Ignore the get_object part in UpdateAPIView.update
+        return
+
     def perform_create(self, serializer):
+        self.run(serializer, delete=False)
+
+    def perform_update(self, serializer):
+        self.run(serializer, delete=True)
+
+    def run(self, serializer, delete=False):
         parent = serializer.validated_data['parent']
         source = DataSource.from_ml_tool(serializer.validated_data['recognizer'])
         trpolygons = build_transcriptions(
@@ -96,8 +104,10 @@
         )
         if not trpolygons:
             return
-        # Clear all previous transcriptions from source
-        parent.transcriptions.filter(source=source).delete()
+
+        if delete:
+            # Clear all previous transcriptions from source
+            parent.transcriptions.filter(source=source).delete()
 
         transcriptions, _ = save_transcriptions(*trpolygons)

diff --git a/arkindex/documents/tests/test_transcription_create.py b/arkindex/documents/tests/test_transcription_create.py
index ced37c3043..9a210f42fb 100644
--- a/arkindex/documents/tests/test_transcription_create.py
+++ b/arkindex/documents/tests/test_transcription_create.py
@@ -192,3 +192,70 @@ class TestTranscriptionCreate(FixtureAPITestCase):
             ]
         })
         self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+
+    @patch('arkindex.project.serializer_fields.MLTool.get')
+    @patch('arkindex.images.importer.Indexer')
+    def test_update_bulk_transcription(self, indexer, ml_get_mock):
+        ml_get_mock.return_value.type = self.src.type
+        ml_get_mock.return_value.slug = self.src.slug
+        ml_get_mock.return_value.version = self.src.revision
+        self.src.internal = True
+        self.src.save()
+
+        data = {
+            "parent": str(self.page.id),
+            "recognizer": self.src.slug,
+            "transcriptions": [
+                {
+                    "type": "word",
+                    "polygon": [(0, 0), (100, 0), (100, 100), (0, 100), (0, 0)],
+                    "text": "NEKUDOTAYIM",
+                    "score": 0.83,
+                },
+                {
+                    "type": "line",
+                    "polygon": [(0, 0), (200, 0), (200, 200), (0, 200), (0, 0)],
+                    "text": "This is a test",
+                    "score": 0.75,
+                },
+            ]
+        }
+
+        self.client.force_login(self.user)
+        response = self.client.put(reverse('api:transcription-bulk'), format='json', data=data)
+        self.assertEqual(response.status_code, status.HTTP_200_OK)
+        new_ts = self.page.transcriptions.get(text="NEKUDOTAYIM", type=TranscriptionType.Word)
+        self.assertEqual(new_ts.zone.polygon, Polygon.from_coords(0, 0, 100, 100))
+        self.assertEqual(new_ts.score, 0.83)
+        self.assertEqual(new_ts.source, self.src)
+
+        page_ts = self.page.transcriptions.get(type=TranscriptionType.Page)
+        self.assertEqual(page_ts.text, 'This is a test')
+        self.assertEqual(page_ts.score, None)
+        self.assertEqual(page_ts.source, self.src)
+        self.assertEqual(page_ts.zone, self.page.zone)
+
+        # Indexer called
+        self.assertEqual(indexer.return_value.run_index.call_count, 2)
+
+        # Update again
+        indexer.reset_mock()
+        data['transcriptions'] = [
+            {
+                "type": "word",
+                "polygon": [(0, 0), (200, 0), (200, 200), (0, 200), (0, 0)],
+                "text": "Hi",
+                "score": 0.99,
+            },
+        ]
+
+        response = self.client.put(reverse('api:transcription-bulk'), format='json', data=data)
+        self.assertEqual(response.status_code, status.HTTP_200_OK)
+        self.assertEqual(indexer.return_value.run_index.call_count, 2)
+
+        # Previous transcriptions should be replaced by a Word and a Page transcription
+        self.assertEqual(self.page.transcriptions.count(), 2)
+        ts = self.page.transcriptions.get(type=TranscriptionType.Word)
+        self.assertEqual(ts.text, "Hi")
+        page_ts = self.page.transcriptions.get(type=TranscriptionType.Page)
+        self.assertEqual(page_ts.text, '')
-- 
GitLab
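A minimal client-side sketch of the behaviour this patch introduces, using the payload shape from the test above: POST goes through perform_create() and run(delete=False), appending transcriptions, while PUT goes through perform_update() and run(delete=True), first clearing the parent's previous transcriptions from the same source. The URL, host, identifiers and authentication header below are placeholders; only the route name 'api:transcription-bulk' appears in the patch.

import requests

# Placeholder endpoint: the patch only references the route name 'api:transcription-bulk',
# not its concrete path, so adjust this URL to the target Arkindex instance.
BULK_URL = 'https://arkindex.example.com/api/v1/transcription/bulk/'

payload = {
    # Illustrative values; the test uses self.page.id and self.src.slug
    'parent': '11111111-2222-3333-4444-555555555555',
    'recognizer': 'my-recognizer-slug',
    'transcriptions': [
        {
            'type': 'word',
            'polygon': [[0, 0], [100, 0], [100, 100], [0, 100], [0, 0]],
            'text': 'NEKUDOTAYIM',
            'score': 0.83,
        },
    ],
}

session = requests.Session()
# Authentication scheme assumed, not shown in the patch
session.headers['Authorization'] = 'Token <api-token>'

# POST -> perform_create() -> run(serializer, delete=False): appends to existing transcriptions
response = session.post(BULK_URL, json=payload)
response.raise_for_status()

# PUT -> perform_update() -> run(serializer, delete=True): deletes the parent's previous
# transcriptions from the same source before saving the new ones
response = session.put(BULK_URL, json=payload)
response.raise_for_status()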