diff --git a/arkindex/documents/api/ml.py b/arkindex/documents/api/ml.py index 259906a0cadbf97ad6e79d5a94f8472385192ea2..f818d89a485288fd4a003d529575e6614cc67e38 100644 --- a/arkindex/documents/api/ml.py +++ b/arkindex/documents/api/ml.py @@ -266,6 +266,7 @@ class ElementTranscriptionsBulk(CreateAPIView): worker_version=worker_version, text=annotation['text'], confidence=annotation['confidence'], + orientation=annotation['orientation'] ) annotation['id'] = transcription.id transcriptions.append(transcription) diff --git a/arkindex/documents/serializers/ml.py b/arkindex/documents/serializers/ml.py index 607f1ad896ce377fd76aad332f0df7924e238f34..a5e6593a78563c56c646040d960009efea066006 100644 --- a/arkindex/documents/serializers/ml.py +++ b/arkindex/documents/serializers/ml.py @@ -328,6 +328,7 @@ class SimpleTranscriptionSerializer(serializers.Serializer): max_value=1, required=False, ) + orientation = EnumField(TextOrientation, default=TextOrientation.HorizontalLeftToRight, required=False) def validate(self, data): if not ('score' in data) ^ ('confidence' in data): diff --git a/arkindex/documents/tests/test_bulk_element_transcriptions.py b/arkindex/documents/tests/test_bulk_element_transcriptions.py index dbb705f96cab73ab998c9203a56552e736336ef8..e1fec25bbac4b4305f5dbec59a6d018ea06d1605 100644 --- a/arkindex/documents/tests/test_bulk_element_transcriptions.py +++ b/arkindex/documents/tests/test_bulk_element_transcriptions.py @@ -6,7 +6,7 @@ from django.urls import reverse from rest_framework import status from arkindex.dataimport.models import WorkerVersion -from arkindex.documents.models import Corpus, Element, Transcription +from arkindex.documents.models import Corpus, Element, TextOrientation, Transcription from arkindex.project.tests import FixtureAPITestCase @@ -39,8 +39,8 @@ class TestBulkElementTranscriptions(FixtureAPITestCase): existing_element_ids = list(Element.objects.get_descending(self.page.id).values_list('id', flat=True)) transcriptions = [ - ([[13, 37], [133, 37], [133, 137], [13, 137], [13, 37]], 'Hello world !', 0.1337), - ([[24, 42], [64, 42], [64, 142], [24, 142], [24, 42]], 'I <3 JavaScript', 0.42), + ([[13, 37], [133, 37], [133, 137], [13, 137], [13, 37]], 'Hello world !', 0.1337, 'vertical-lr'), + ([[24, 42], [64, 42], [64, 142], [24, 142], [24, 42]], 'I <3 JavaScript', 0.42, 'horizontal-lr'), ] data = { 'element_type': 'text_line', @@ -48,8 +48,9 @@ class TestBulkElementTranscriptions(FixtureAPITestCase): 'transcriptions': [{ 'polygon': poly, 'text': text, - 'confidence': confidence - } for poly, text, confidence in transcriptions] + 'confidence': confidence, + 'orientation': orientation + } for poly, text, confidence, orientation in transcriptions] } self.client.force_login(self.user) response = self.client.post( @@ -73,10 +74,10 @@ class TestBulkElementTranscriptions(FixtureAPITestCase): ] ) self.assertCountEqual( - created_elts.values_list('transcriptions__text', 'transcriptions__worker_version'), + created_elts.values_list('transcriptions__text', 'transcriptions__worker_version', 'transcriptions__orientation'), [ - ('Hello world !', self.worker_version.id), - ('I <3 JavaScript', self.worker_version.id) + ('Hello world !', self.worker_version.id, TextOrientation.VerticalLeftToRight), + ('I <3 JavaScript', self.worker_version.id, TextOrientation.HorizontalLeftToRight) ] ) @@ -354,7 +355,7 @@ class TestBulkElementTranscriptions(FixtureAPITestCase): second_elt = created_elts.get(transcriptions__text='stock') first_tr = Transcription.objects.get(text="stuck") second_tr = Transcription.objects.get(text="stock") - thrid_td = Transcription.objects.get(text="stack") + third_td = Transcription.objects.get(text="stack") self.assertListEqual(response.json(), [{ 'id': str(first_tr.id), 'element_id': str(first_elt.id), @@ -364,7 +365,7 @@ class TestBulkElementTranscriptions(FixtureAPITestCase): 'element_id': str(second_elt.id), 'created': True }, { - 'id': str(thrid_td.id), + 'id': str(third_td.id), 'element_id': str(self.line.id), 'created': False }]) @@ -432,3 +433,65 @@ class TestBulkElementTranscriptions(FixtureAPITestCase): ('I <3 JavaScript', self.worker_version.id) ] ) + + def test_bulk_transcriptions_default_orientation(self): + """ + No text orientation set > default (horizontal-lr) value + """ + existing_element_ids = list(Element.objects.get_descending(self.page.id).values_list('id', flat=True)) + transcriptions = [ + ([[13, 37], [133, 37], [133, 137], [13, 137], [13, 37]], 'Hello world !', 0.1337), + ([[24, 42], [64, 42], [64, 142], [24, 142], [24, 42]], 'I <3 JavaScript', 0.42), + ] + data = { + 'element_type': 'text_line', + 'worker_version': str(self.worker_version.id), + 'transcriptions': [{ + 'polygon': poly, + 'text': text, + 'confidence': confidence, + } for poly, text, confidence in transcriptions] + } + self.client.force_login(self.user) + response = self.client.post( + reverse('api:element-transcriptions-bulk', kwargs={'pk': self.page.id}), + format='json', + data=data + ) + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + created_elts = Element.objects.get_descending(self.page.id).exclude(id__in=existing_element_ids) + self.assertCountEqual( + created_elts.values_list('transcriptions__orientation', flat=True), + [TextOrientation.HorizontalLeftToRight, TextOrientation.HorizontalLeftToRight] + ) + + def test_bulk_transcriptions_invalid_orientation(self): + """ + Specifying an invalid text orientation throws an error + """ + transcriptions = [ + ([[13, 37], [133, 37], [133, 137], [13, 137], [13, 37]], 'Hello world !', 0.1337, 'wiggly'), + ([[24, 42], [64, 42], [64, 142], [24, 142], [24, 42]], 'I <3 JavaScript', 0.42, 'timey-wimey'), + ] + data = { + 'element_type': 'text_line', + 'worker_version': str(self.worker_version.id), + 'transcriptions': [{ + 'polygon': poly, + 'text': text, + 'confidence': confidence, + 'orientation': orientation, + } for poly, text, confidence, orientation in transcriptions] + } + self.client.force_login(self.user) + response = self.client.post( + reverse('api:element-transcriptions-bulk', kwargs={'pk': self.page.id}), + format='json', + data=data + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual(response.json(), {'transcriptions': [ + {'orientation': ['Value is not of type TextOrientation']}, + {'orientation': ['Value is not of type TextOrientation']} + ]} + )