From c51e18ef522f8a9d6037809882b591743c3433d4 Mon Sep 17 00:00:00 2001 From: Erwan Rouchet <rouchet@teklia.com> Date: Mon, 26 Oct 2020 15:30:45 +0000 Subject: [PATCH] CreateTranscriptions endpoint --- arkindex/documents/api/ml.py | 26 ++--- arkindex/documents/serializers/ml.py | 50 +++++++++ .../tests/test_bulk_transcriptions.py | 106 ++++++++++++++++++ arkindex/project/openapi/patch.yml | 8 -- 4 files changed, 167 insertions(+), 23 deletions(-) create mode 100644 arkindex/documents/tests/test_bulk_transcriptions.py diff --git a/arkindex/documents/api/ml.py b/arkindex/documents/api/ml.py index 68dff4af39..fd3820662c 100644 --- a/arkindex/documents/api/ml.py +++ b/arkindex/documents/api/ml.py @@ -3,8 +3,8 @@ from django.db import transaction from django.db.models import Q, Count from rest_framework import status from rest_framework.generics import ( - GenericAPIView, ListAPIView, ListCreateAPIView, - CreateAPIView, UpdateAPIView, RetrieveDestroyAPIView, RetrieveUpdateDestroyAPIView + GenericAPIView, ListAPIView, ListCreateAPIView, CreateAPIView, + RetrieveDestroyAPIView, RetrieveUpdateDestroyAPIView ) from rest_framework.exceptions import PermissionDenied, ValidationError from rest_framework.response import Response @@ -13,9 +13,10 @@ from arkindex.documents.models import \ from arkindex_common.ml_tool import MLToolType from arkindex.documents.serializers.ml import ( ClassificationsSerializer, ClassificationCreateSerializer, ClassificationSerializer, - TranscriptionSerializer, TranscriptionCreateSerializer, ElementTranscriptionsBulkSerializer, - DataSourceStatsSerializer, ClassificationsSelectionSerializer, ClassificationMode, - CountMLClassSerializer, AnnotatedElementSerializer + ClassificationsSelectionSerializer, ClassificationMode, + TranscriptionSerializer, TranscriptionCreateSerializer, TranscriptionBulkSerializer, + ElementTranscriptionsBulkSerializer, AnnotatedElementSerializer, + DataSourceStatsSerializer, CountMLClassSerializer ) from arkindex.images.models import Zone from arkindex.project.filters import SafeSearchFilter @@ -292,21 +293,16 @@ class ElementTranscriptionsBulk(CreateAPIView): return annotations -class TranscriptionBulk(DeprecatedMixin, CreateAPIView, UpdateAPIView): +class TranscriptionBulk(CreateAPIView): ''' - Create multiple transcriptions at once, all linked to the same page - and to the same recognizer. + Create multiple transcriptions at once on existing elements ''' - # Force DRF to ignore PATCH - http_method_names = ['post', 'put', 'head', 'options', 'trace'] openapi_overrides = { + 'operationId': 'CreateTranscriptions', 'tags': ['transcriptions'], } - deprecation_message = ( - 'Creating or updating transcriptions with their own zones is now deprecated. ' - 'Please use CreateElementTranscriptions to push transcriptions in bulk ' - 'attached to sub-elements.' - ) + permission_classes = (IsVerified, ) + serializer_class = TranscriptionBulkSerializer class CorpusMLClassList(CorpusACLMixin, ListCreateAPIView): diff --git a/arkindex/documents/serializers/ml.py b/arkindex/documents/serializers/ml.py index 40ef2f086a..7fa83582f1 100644 --- a/arkindex/documents/serializers/ml.py +++ b/arkindex/documents/serializers/ml.py @@ -369,6 +369,56 @@ class AnnotatedElementSerializer(serializers.Serializer): created = serializers.BooleanField(default=False) +class TranscriptionBulkItemSerializer(serializers.Serializer): + # Element retrieval and checks is done in the BulkSerializer to avoid duplicate queries + element_id = serializers.UUIDField( + help_text='ID of an existing element to add the transcription to' + ) + type = EnumField(TranscriptionType, help_text='Type of the transcription') + text = serializers.CharField() + score = serializers.FloatField(min_value=0, max_value=1) + + +class TranscriptionBulkSerializer(serializers.Serializer): + worker_version = serializers.PrimaryKeyRelatedField(queryset=WorkerVersion.objects.all()) + transcriptions = TranscriptionBulkItemSerializer(many=True) + + def validate(self, data): + element_ids = set(transcription['element_id'] for transcription in data['transcriptions']) + found_ids = set(Element.objects.filter( + id__in=element_ids, + corpus__in=Corpus.objects.writable(self.context['request'].user) + ).values_list('id', flat=True)) + + missing_ids = element_ids - found_ids + if not missing_ids: + return data + + # Return an error message as a list just like DRF's ListField, for easier debugging + raise ValidationError({'transcriptions': [ + {"element_id": [f'Element {transcription["element_id"]} was not found or cannot be written to.']} + if transcription['element_id'] in missing_ids + else {} + for i, transcription in enumerate(data['transcriptions']) + ]}) + + def create(self, validated_data): + transcriptions = [ + Transcription( + worker_version=validated_data['worker_version'], + element_id=transcription['element_id'], + type=transcription['type'], + text=transcription['text'], + score=transcription['score'], + ) + for transcription in validated_data['transcriptions'] + ] + Transcription.objects.bulk_create(transcriptions) + + validated_data['transcriptions'] = transcriptions + return validated_data + + class ClassificationBulkSerializer(serializers.Serializer): """ Single classification serializer for bulk insertion diff --git a/arkindex/documents/tests/test_bulk_transcriptions.py b/arkindex/documents/tests/test_bulk_transcriptions.py new file mode 100644 index 0000000000..12cb09e2b5 --- /dev/null +++ b/arkindex/documents/tests/test_bulk_transcriptions.py @@ -0,0 +1,106 @@ +from django.urls import reverse +from rest_framework import status +from arkindex_common.enums import TranscriptionType +from arkindex.project.tests import FixtureAPITestCase +from arkindex.dataimport.models import WorkerVersion + + +class TestBulkTranscriptions(FixtureAPITestCase): + + @classmethod + def setUpTestData(cls): + super().setUpTestData() + cls.worker_version = WorkerVersion.objects.get(worker__slug='reco') + + def test_bulk_transcriptions_requires_login(self): + with self.assertNumQueries(0): + response = self.client.post(reverse('api:transcription-bulk')) + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + + def test_bulk_transcriptions_not_found(self): + self.client.force_login(self.user) + self.user.corpus_right.all().delete() + forbidden_element = self.corpus.elements.get(name='Volume 1, page 1r') + with self.assertNumQueries(4): + response = self.client.post(reverse('api:transcription-bulk'), { + "worker_version": str(self.worker_version.id), + "transcriptions": [ + { + "element_id": str(forbidden_element.id), + "type": TranscriptionType.Word.value, + "text": "lol", + "score": 0.4 + }, + { + "element_id": "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa", + "type": TranscriptionType.Word.value, + "text": "lol", + "score": 0.4 + } + ], + }, format='json') + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual(response.json(), { + 'transcriptions': [ + {'element_id': [f'Element {forbidden_element.id} was not found or cannot be written to.']}, + {'element_id': ['Element aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa was not found or cannot be written to.']}, + ] + }) + + def test_bulk_transcriptions(self): + self.client.force_login(self.user) + + element1 = self.corpus.elements.get(name='Volume 2') + element2 = self.corpus.elements.get(name='Volume 2, page 1r') + self.assertFalse(element1.transcriptions.exists()) + self.assertFalse(element2.transcriptions.exists()) + + with self.assertNumQueries(5): + response = self.client.post(reverse('api:transcription-bulk'), { + "worker_version": str(self.worker_version.id), + "transcriptions": [ + { + "element_id": str(element1.id), + "type": TranscriptionType.Word.value, + "text": "Sneasel", + "score": 0.54 + }, + { + "element_id": str(element2.id), + "type": TranscriptionType.Line.value, + "text": "Charizard", + "score": 0.85 + }, + { + "element_id": str(element1.id), + "type": TranscriptionType.Word.value, + "text": "Raticate", + "score": 0.12 + }, + ], + }, format='json') + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + + self.assertCountEqual( + list(element1.transcriptions.values('type', 'text', 'score')), + [ + { + "type": TranscriptionType.Word, + "text": "Sneasel", + "score": 0.54 + }, + { + "type": TranscriptionType.Word, + "text": "Raticate", + "score": 0.12 + }, + ] + ) + self.assertCountEqual( + list(element2.transcriptions.values('type', 'text', 'score')), + [{ + "type": TranscriptionType.Line, + "text": "Charizard", + "score": 0.85 + }] + ) diff --git a/arkindex/project/openapi/patch.yml b/arkindex/project/openapi/patch.yml index 811ed843d3..66cc7c59da 100644 --- a/arkindex/project/openapi/patch.yml +++ b/arkindex/project/openapi/patch.yml @@ -283,14 +283,6 @@ paths: description: Update the text of a manual transcription delete: description: Delete a manual transcription - /api/v1/transcription/bulk/: - post: - operationId: CreateTranscriptions - put: - operationId: UpdateTranscriptions - description: >- - Replace all existing transcriptions from a given recognizer on a page - with other transcriptions. /api/v1/metadata/{id}/: get: operationId: RetrieveMetaData -- GitLab