From 9244ade390e4c1656af50f4f4fbdeea0b6550e27 Mon Sep 17 00:00:00 2001 From: mlbonhomme <bonhomme@teklia.com> Date: Wed, 17 Nov 2021 08:38:12 +0000 Subject: [PATCH] Text orientation in base worker --- .isort.cfg | 2 +- arkindex_worker/cache.py | 1 + arkindex_worker/worker/transcription.py | 55 +- tests/conftest.py | 8 + tests/test_cache.py | 2 +- tests/test_elements_worker/test_entities.py | 2 + .../test_transcriptions.py | 494 +++++++++++++++++- 7 files changed, 553 insertions(+), 11 deletions(-) diff --git a/.isort.cfg b/.isort.cfg index 2fe5c98a..6d30d8e7 100644 --- a/.isort.cfg +++ b/.isort.cfg @@ -8,4 +8,4 @@ line_length = 88 default_section=FIRSTPARTY known_first_party = arkindex,arkindex_common -known_third_party =PIL,apistar,gitlab,gnupg,peewee,pytest,requests,setuptools,sh,tenacity,yaml +known_third_party =PIL,apistar,gitlab,gnupg,peewee,playhouse,pytest,requests,setuptools,sh,tenacity,yaml diff --git a/arkindex_worker/cache.py b/arkindex_worker/cache.py index 30141046..b184460d 100644 --- a/arkindex_worker/cache.py +++ b/arkindex_worker/cache.py @@ -124,6 +124,7 @@ class CachedTranscription(Model): element = ForeignKeyField(CachedElement, backref="transcriptions") text = TextField() confidence = FloatField() + orientation = CharField(max_length=50) worker_version_id = UUIDField() class Meta: diff --git a/arkindex_worker/worker/transcription.py b/arkindex_worker/worker/transcription.py index 388f756c..8ff8342e 100644 --- a/arkindex_worker/worker/transcription.py +++ b/arkindex_worker/worker/transcription.py @@ -1,5 +1,7 @@ # -*- coding: utf-8 -*- +from enum import Enum + from peewee import IntegrityError from arkindex_worker import logger @@ -7,8 +9,17 @@ from arkindex_worker.cache import CachedElement, CachedTranscription from arkindex_worker.models import Element +class TextOrientation(Enum): + HorizontalLeftToRight = "horizontal-lr" + HorizontalRightToLeft = "horizontal-rl" + VerticalRightToLeft = "vertical-rl" + VerticalLeftToRight = "vertical-lr" + + class TranscriptionMixin(object): - def create_transcription(self, element, text, score): + def create_transcription( + self, element, text, score, orientation=TextOrientation.HorizontalLeftToRight + ): """ Create a transcription on the given element through the API. """ @@ -18,7 +29,9 @@ class TranscriptionMixin(object): assert text and isinstance( text, str ), "text shouldn't be null and should be of type str" - + assert orientation and isinstance( + orientation, TextOrientation + ), "orientation shouldn't be null and should be of type TextOrientation" assert ( isinstance(score, float) and 0 <= score <= 1 ), "score shouldn't be null and should be a float in [0..1] range" @@ -36,6 +49,7 @@ class TranscriptionMixin(object): "text": text, "worker_version": self.worker_version_id, "score": score, + "orientation": orientation.value, }, ) @@ -50,6 +64,7 @@ class TranscriptionMixin(object): "element_id": element.id, "text": created["text"], "confidence": created["confidence"], + "orientation": created["orientation"], "worker_version_id": self.worker_version_id, } ] @@ -70,7 +85,10 @@ class TranscriptionMixin(object): transcriptions, list ), "transcriptions shouldn't be null and should be of type list" - for index, transcription in enumerate(transcriptions): + # Create shallow copies of every transcription to avoid mutating the original payload + transcriptions_payload = list(map(dict, transcriptions)) + + for (index, transcription) in enumerate(transcriptions_payload): element_id = transcription.get("element_id") assert element_id and isinstance( element_id, str @@ -86,11 +104,20 @@ class TranscriptionMixin(object): score is not None and isinstance(score, float) and 0 <= score <= 1 ), f"Transcription at index {index} in transcriptions: score shouldn't be null and should be a float in [0..1] range" + orientation = transcription.get( + "orientation", TextOrientation.HorizontalLeftToRight + ) + assert orientation and isinstance( + orientation, TextOrientation + ), f"Transcription at index {index} in transcriptions: orientation shouldn't be null and should be of type TextOrientation" + if orientation: + transcription["orientation"] = orientation.value + created_trs = self.request( "CreateTranscriptions", body={ "worker_version": self.worker_version_id, - "transcriptions": transcriptions, + "transcriptions": transcriptions_payload, }, )["transcriptions"] @@ -106,6 +133,7 @@ class TranscriptionMixin(object): "element_id": created_tr["element_id"], "text": created_tr["text"], "confidence": created_tr["confidence"], + "orientation": created_tr["orientation"], "worker_version_id": self.worker_version_id, } for created_tr in created_trs @@ -132,7 +160,10 @@ class TranscriptionMixin(object): transcriptions, list ), "transcriptions shouldn't be null and should be of type list" - for index, transcription in enumerate(transcriptions): + # Create shallow copies of every transcription to avoid mutating the original payload + transcriptions_payload = list(map(dict, transcriptions)) + + for (index, transcription) in enumerate(transcriptions_payload): text = transcription.get("text") assert text and isinstance( text, str @@ -143,6 +174,15 @@ class TranscriptionMixin(object): score is not None and isinstance(score, float) and 0 <= score <= 1 ), f"Transcription at index {index} in transcriptions: score shouldn't be null and should be a float in [0..1] range" + orientation = transcription.get( + "orientation", TextOrientation.HorizontalLeftToRight + ) + assert orientation and isinstance( + orientation, TextOrientation + ), f"Transcription at index {index} in transcriptions: orientation shouldn't be null and should be of type TextOrientation" + if orientation: + transcription["orientation"] = orientation.value + polygon = transcription.get("polygon") assert polygon and isinstance( polygon, list @@ -168,7 +208,7 @@ class TranscriptionMixin(object): body={ "element_type": sub_element_type, "worker_version": self.worker_version_id, - "transcriptions": transcriptions, + "transcriptions": transcriptions_payload, "return_elements": True, }, ) @@ -216,6 +256,9 @@ class TranscriptionMixin(object): "element_id": annotation["element_id"], "text": transcription["text"], "confidence": transcription["score"], + "orientation": transcription.get( + "orientation", TextOrientation.HorizontalLeftToRight + ).value, "worker_version_id": self.worker_version_id, } ) diff --git a/tests/conftest.py b/tests/conftest.py index 5a4596c6..aa7e521f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -15,6 +15,7 @@ from arkindex.mock import MockApiClient from arkindex_worker.cache import MODELS, CachedElement, CachedTranscription from arkindex_worker.git import GitHelper, GitlabHelper from arkindex_worker.worker import BaseWorker, ElementsWorker +from arkindex_worker.worker.transcription import TextOrientation FIXTURES_DIR = Path(__file__).resolve().parent / "data" @@ -381,6 +382,7 @@ def mock_cached_transcriptions(): element_id=UUID("11111111-1111-1111-1111-111111111111"), text="This", confidence=0.42, + orientation=TextOrientation.HorizontalLeftToRight, worker_version_id=UUID("56785678-5678-5678-5678-567856785678"), ) CachedTranscription.create( @@ -388,6 +390,7 @@ def mock_cached_transcriptions(): element_id=UUID("22222222-2222-2222-2222-222222222222"), text="is", confidence=0.42, + orientation=TextOrientation.HorizontalLeftToRight, worker_version_id=UUID("90129012-9012-9012-9012-901290129012"), ) CachedTranscription.create( @@ -395,6 +398,7 @@ def mock_cached_transcriptions(): element_id=UUID("33333333-3333-3333-3333-333333333333"), text="a", confidence=0.42, + orientation=TextOrientation.HorizontalLeftToRight, worker_version_id=UUID("90129012-9012-9012-9012-901290129012"), ) CachedTranscription.create( @@ -402,6 +406,7 @@ def mock_cached_transcriptions(): element_id=UUID("44444444-4444-4444-4444-444444444444"), text="good", confidence=0.42, + orientation=TextOrientation.HorizontalLeftToRight, worker_version_id=UUID("90129012-9012-9012-9012-901290129012"), ) CachedTranscription.create( @@ -409,6 +414,7 @@ def mock_cached_transcriptions(): element_id=UUID("55555555-5555-5555-5555-555555555555"), text="test", confidence=0.42, + orientation=TextOrientation.HorizontalLeftToRight, worker_version_id=UUID("90129012-9012-9012-9012-901290129012"), ) @@ -461,6 +467,7 @@ def mock_databases(tmpdir): element_id=UUID("42424242-4242-4242-4242-424242424242"), text="Hello!", confidence=0.42, + orientation=TextOrientation.HorizontalLeftToRight, worker_version_id=UUID("56785678-5678-5678-5678-567856785678"), ) @@ -477,6 +484,7 @@ def mock_databases(tmpdir): element_id=UUID("42424242-4242-4242-4242-424242424242"), text="Hello again neighbor !", confidence=0.42, + orientation=TextOrientation.HorizontalLeftToRight, worker_version_id=UUID("56785678-5678-5678-5678-567856785678"), ) diff --git a/tests/test_cache.py b/tests/test_cache.py index 121fa580..03a68144 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -59,7 +59,7 @@ CREATE TABLE "elements" ("id" TEXT NOT NULL PRIMARY KEY, "parent_id" TEXT, "type CREATE TABLE "entities" ("id" TEXT NOT NULL PRIMARY KEY, "type" VARCHAR(50) NOT NULL, "name" TEXT NOT NULL, "validated" INTEGER NOT NULL, "metas" text, "worker_version_id" TEXT NOT NULL) CREATE TABLE "images" ("id" TEXT NOT NULL PRIMARY KEY, "width" INTEGER NOT NULL, "height" INTEGER NOT NULL, "url" TEXT NOT NULL) CREATE TABLE "transcription_entities" ("transcription_id" TEXT NOT NULL, "entity_id" TEXT NOT NULL, "offset" INTEGER NOT NULL CHECK (offset >= 0), "length" INTEGER NOT NULL CHECK (length > 0), "worker_version_id" TEXT NOT NULL, PRIMARY KEY ("transcription_id", "entity_id"), FOREIGN KEY ("transcription_id") REFERENCES "transcriptions" ("id"), FOREIGN KEY ("entity_id") REFERENCES "entities" ("id")) -CREATE TABLE "transcriptions" ("id" TEXT NOT NULL PRIMARY KEY, "element_id" TEXT NOT NULL, "text" TEXT NOT NULL, "confidence" REAL NOT NULL, "worker_version_id" TEXT NOT NULL, FOREIGN KEY ("element_id") REFERENCES "elements" ("id"))""" +CREATE TABLE "transcriptions" ("id" TEXT NOT NULL PRIMARY KEY, "element_id" TEXT NOT NULL, "text" TEXT NOT NULL, "confidence" REAL NOT NULL, "orientation" VARCHAR(50) NOT NULL, "worker_version_id" TEXT NOT NULL, FOREIGN KEY ("element_id") REFERENCES "elements" ("id"))""" actual_schema = "\n".join( [ diff --git a/tests/test_elements_worker/test_entities.py b/tests/test_elements_worker/test_entities.py index 8b16ba15..e4f43b6f 100644 --- a/tests/test_elements_worker/test_entities.py +++ b/tests/test_elements_worker/test_entities.py @@ -13,6 +13,7 @@ from arkindex_worker.cache import ( ) from arkindex_worker.models import Element from arkindex_worker.worker import EntityType +from arkindex_worker.worker.transcription import TextOrientation from . import BASE_API_CALLS @@ -465,6 +466,7 @@ def test_create_transcription_entity_with_cache( element=UUID("12341234-1234-1234-1234-123412341234"), text="Hello, it's me.", confidence=0.42, + orientation=TextOrientation.HorizontalLeftToRight, worker_version_id=UUID("12341234-1234-1234-1234-123412341234"), ) CachedEntity.create( diff --git a/tests/test_elements_worker/test_transcriptions.py b/tests/test_elements_worker/test_transcriptions.py index d7166b54..4a54f532 100644 --- a/tests/test_elements_worker/test_transcriptions.py +++ b/tests/test_elements_worker/test_transcriptions.py @@ -4,9 +4,11 @@ from uuid import UUID import pytest from apistar.exceptions import ErrorResponse +from playhouse.shortcuts import model_to_dict from arkindex_worker.cache import CachedElement, CachedTranscription from arkindex_worker.models import Element +from arkindex_worker.worker.transcription import TextOrientation from . import BASE_API_CALLS @@ -117,6 +119,74 @@ def test_create_transcription_wrong_score(mock_elements_worker): ) +def test_create_transcription_default_orientation(responses, mock_elements_worker): + elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) + responses.add( + responses.POST, + f"http://testserver/api/v1/element/{elt.id}/transcription/", + status=200, + json={ + "id": "56785678-5678-5678-5678-567856785678", + "text": "Animula vagula blandula", + "confidence": 0.42, + "worker_version_id": "12341234-1234-1234-1234-123412341234", + }, + ) + mock_elements_worker.create_transcription( + element=elt, + text="Animula vagula blandula", + score=0.42, + ) + assert json.loads(responses.calls[-1].request.body) == { + "text": "Animula vagula blandula", + "worker_version": "12341234-1234-1234-1234-123412341234", + "score": 0.42, + "orientation": "horizontal-lr", + } + + +def test_create_transcription_orientation(responses, mock_elements_worker): + elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) + responses.add( + responses.POST, + f"http://testserver/api/v1/element/{elt.id}/transcription/", + status=200, + json={ + "id": "56785678-5678-5678-5678-567856785678", + "text": "Animula vagula blandula", + "confidence": 0.42, + "worker_version_id": "12341234-1234-1234-1234-123412341234", + }, + ) + mock_elements_worker.create_transcription( + element=elt, + text="Animula vagula blandula", + orientation=TextOrientation.VerticalLeftToRight, + score=0.42, + ) + assert json.loads(responses.calls[-1].request.body) == { + "text": "Animula vagula blandula", + "worker_version": "12341234-1234-1234-1234-123412341234", + "score": 0.42, + "orientation": "vertical-lr", + } + + +def test_create_transcription_wrong_orientation(mock_elements_worker): + elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) + with pytest.raises(AssertionError) as e: + mock_elements_worker.create_transcription( + element=elt, + text="Animula vagula blandula", + score=0.26, + orientation="eliptical", + ) + assert ( + str(e.value) + == "orientation shouldn't be null and should be of type TextOrientation" + ) + + def test_create_transcription_api_error(responses, mock_elements_worker): elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) responses.add( @@ -177,6 +247,7 @@ def test_create_transcription(responses, mock_elements_worker): "text": "i am a line", "worker_version": "12341234-1234-1234-1234-123412341234", "score": 0.42, + "orientation": "horizontal-lr", } @@ -192,6 +263,7 @@ def test_create_transcription_with_cache(responses, mock_elements_worker_with_ca "text": "i am a line", "score": 0.42, "confidence": 0.42, + "orientation": "horizontal-lr", "worker_version_id": "12341234-1234-1234-1234-123412341234", }, ) @@ -212,6 +284,7 @@ def test_create_transcription_with_cache(responses, mock_elements_worker_with_ca assert json.loads(responses.calls[-1].request.body) == { "text": "i am a line", "worker_version": "12341234-1234-1234-1234-123412341234", + "orientation": "horizontal-lr", "score": 0.42, } @@ -222,11 +295,63 @@ def test_create_transcription_with_cache(responses, mock_elements_worker_with_ca element_id=UUID(elt.id), text="i am a line", confidence=0.42, + orientation=TextOrientation.HorizontalLeftToRight, worker_version_id=UUID("12341234-1234-1234-1234-123412341234"), ) ] +def test_create_transcription_orientation_with_cache( + responses, mock_elements_worker_with_cache +): + elt = CachedElement.create(id="12341234-1234-1234-1234-123412341234", type="thing") + responses.add( + responses.POST, + f"http://testserver/api/v1/element/{elt.id}/transcription/", + status=200, + json={ + "id": "56785678-5678-5678-5678-567856785678", + "text": "Animula vagula blandula", + "confidence": 0.42, + "orientation": "vertical-lr", + "worker_version_id": "12341234-1234-1234-1234-123412341234", + }, + ) + mock_elements_worker_with_cache.create_transcription( + element=elt, + text="Animula vagula blandula", + orientation=TextOrientation.VerticalLeftToRight, + score=0.42, + ) + assert json.loads(responses.calls[-1].request.body) == { + "text": "Animula vagula blandula", + "worker_version": "12341234-1234-1234-1234-123412341234", + "orientation": "vertical-lr", + "score": 0.42, + } + # Check that the text orientation was properly stored in SQLite cache + assert list(map(model_to_dict, CachedTranscription.select())) == [ + { + "id": UUID("56785678-5678-5678-5678-567856785678"), + "element": { + "id": UUID("12341234-1234-1234-1234-123412341234"), + "parent_id": None, + "type": "thing", + "image": None, + "polygon": None, + "rotation_angle": 0, + "mirrored": False, + "initial": False, + "worker_version_id": None, + }, + "text": "Animula vagula blandula", + "confidence": 0.42, + "orientation": TextOrientation.VerticalLeftToRight.value, + "worker_version_id": UUID("12341234-1234-1234-1234-123412341234"), + } + ] + + def test_create_transcriptions_wrong_transcriptions(mock_elements_worker): with pytest.raises(AssertionError) as e: mock_elements_worker.create_transcriptions( @@ -457,6 +582,27 @@ def test_create_transcriptions_wrong_transcriptions(mock_elements_worker): == "Transcription at index 1 in transcriptions: score shouldn't be null and should be a float in [0..1] range" ) + with pytest.raises(AssertionError) as e: + mock_elements_worker.create_transcriptions( + transcriptions=[ + { + "element_id": "11111111-1111-1111-1111-111111111111", + "text": "The", + "score": 0.75, + }, + { + "element_id": "11111111-1111-1111-1111-111111111111", + "text": "word", + "score": 0.28, + "orientation": "wobble", + }, + ], + ) + assert ( + str(e.value) + == "Transcription at index 1 in transcriptions: orientation shouldn't be null and should be of type TextOrientation" + ) + def test_create_transcriptions_api_error(responses, mock_elements_worker): responses.add( @@ -519,12 +665,14 @@ def test_create_transcriptions(responses, mock_elements_worker_with_cache): "id": "00000000-0000-0000-0000-000000000000", "element_id": "11111111-1111-1111-1111-111111111111", "text": "The", + "orientation": "horizontal-lr", "confidence": 0.75, }, { "id": "11111111-1111-1111-1111-111111111111", "element_id": "11111111-1111-1111-1111-111111111111", "text": "word", + "orientation": "horizontal-lr", "confidence": 0.42, }, ], @@ -544,7 +692,20 @@ def test_create_transcriptions(responses, mock_elements_worker_with_cache): assert json.loads(responses.calls[-1].request.body) == { "worker_version": "12341234-1234-1234-1234-123412341234", - "transcriptions": trans, + "transcriptions": [ + { + "element_id": "11111111-1111-1111-1111-111111111111", + "text": "The", + "score": 0.75, + "orientation": TextOrientation.HorizontalLeftToRight.value, + }, + { + "element_id": "11111111-1111-1111-1111-111111111111", + "text": "word", + "score": 0.42, + "orientation": TextOrientation.HorizontalLeftToRight.value, + }, + ], } # Check that created transcriptions were properly stored in SQLite cache @@ -554,6 +715,7 @@ def test_create_transcriptions(responses, mock_elements_worker_with_cache): element_id=UUID("11111111-1111-1111-1111-111111111111"), text="The", confidence=0.75, + orientation=TextOrientation.HorizontalLeftToRight, worker_version_id=UUID("12341234-1234-1234-1234-123412341234"), ), CachedTranscription( @@ -561,11 +723,117 @@ def test_create_transcriptions(responses, mock_elements_worker_with_cache): element_id=UUID("11111111-1111-1111-1111-111111111111"), text="word", confidence=0.42, + orientation=TextOrientation.HorizontalLeftToRight, worker_version_id=UUID("12341234-1234-1234-1234-123412341234"), ), ] +def test_create_transcriptions_orientation(responses, mock_elements_worker_with_cache): + CachedElement.create(id="11111111-1111-1111-1111-111111111111", type="thing") + trans = [ + { + "element_id": "11111111-1111-1111-1111-111111111111", + "text": "Animula vagula blandula", + "score": 0.12, + "orientation": TextOrientation.HorizontalRightToLeft, + }, + { + "element_id": "11111111-1111-1111-1111-111111111111", + "text": "Hospes comesque corporis", + "score": 0.21, + "orientation": TextOrientation.VerticalLeftToRight, + }, + ] + + responses.add( + responses.POST, + "http://testserver/api/v1/transcription/bulk/", + status=200, + json={ + "worker_version": "12341234-1234-1234-1234-123412341234", + "transcriptions": [ + { + "id": "00000000-0000-0000-0000-000000000000", + "element_id": "11111111-1111-1111-1111-111111111111", + "text": "Animula vagula blandula", + "orientation": "horizontal-rl", + "confidence": 0.12, + }, + { + "id": "11111111-1111-1111-1111-111111111111", + "element_id": "11111111-1111-1111-1111-111111111111", + "text": "Hospes comesque corporis", + "orientation": "vertical-lr", + "confidence": 0.21, + }, + ], + }, + ) + + mock_elements_worker_with_cache.create_transcriptions( + transcriptions=trans, + ) + + assert json.loads(responses.calls[-1].request.body) == { + "worker_version": "12341234-1234-1234-1234-123412341234", + "transcriptions": [ + { + "element_id": "11111111-1111-1111-1111-111111111111", + "text": "Animula vagula blandula", + "score": 0.12, + "orientation": TextOrientation.HorizontalRightToLeft.value, + }, + { + "element_id": "11111111-1111-1111-1111-111111111111", + "text": "Hospes comesque corporis", + "score": 0.21, + "orientation": TextOrientation.VerticalLeftToRight.value, + }, + ], + } + + # Check that oriented transcriptions were properly stored in SQLite cache + assert list(map(model_to_dict, CachedTranscription.select())) == [ + { + "id": UUID("00000000-0000-0000-0000-000000000000"), + "element": { + "id": UUID("11111111-1111-1111-1111-111111111111"), + "parent_id": None, + "type": "thing", + "image": None, + "polygon": None, + "rotation_angle": 0, + "mirrored": False, + "initial": False, + "worker_version_id": None, + }, + "text": "Animula vagula blandula", + "confidence": 0.12, + "orientation": TextOrientation.HorizontalRightToLeft.value, + "worker_version_id": UUID("12341234-1234-1234-1234-123412341234"), + }, + { + "id": UUID("11111111-1111-1111-1111-111111111111"), + "element": { + "id": UUID("11111111-1111-1111-1111-111111111111"), + "parent_id": None, + "type": "thing", + "image": None, + "polygon": None, + "rotation_angle": 0, + "mirrored": False, + "initial": False, + "worker_version_id": None, + }, + "text": "Hospes comesque corporis", + "confidence": 0.21, + "orientation": TextOrientation.VerticalLeftToRight.value, + "worker_version_id": UUID("12341234-1234-1234-1234-123412341234"), + }, + ] + + def test_create_element_transcriptions_wrong_element(mock_elements_worker): with pytest.raises(AssertionError) as e: mock_elements_worker.create_element_transcriptions( @@ -941,6 +1209,29 @@ def test_create_element_transcriptions_wrong_transcriptions(mock_elements_worker == "Transcription at index 1 in transcriptions: polygon points should be lists of two numbers" ) + with pytest.raises(AssertionError) as e: + mock_elements_worker.create_element_transcriptions( + element=elt, + sub_element_type="page", + transcriptions=[ + { + "polygon": [[0, 0], [2000, 0], [2000, 3000], [0, 3000]], + "score": 0.75, + "text": "The", + }, + { + "polygon": [[100, 150], [700, 150], [700, 200], [100, 200]], + "score": 0.35, + "text": "word", + "orientation": "uptown", + }, + ], + ) + assert ( + str(e.value) + == "Transcription at index 1 in transcriptions: orientation shouldn't be null and should be of type TextOrientation" + ) + def test_create_element_transcriptions_api_error(responses, mock_elements_worker): elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) @@ -1011,7 +1302,26 @@ def test_create_element_transcriptions(responses, mock_elements_worker): assert json.loads(responses.calls[-1].request.body) == { "element_type": "page", "worker_version": "12341234-1234-1234-1234-123412341234", - "transcriptions": TRANSCRIPTIONS_SAMPLE, + "transcriptions": [ + { + "polygon": [[100, 150], [700, 150], [700, 200], [100, 200]], + "score": 0.5, + "text": "The", + "orientation": TextOrientation.HorizontalLeftToRight.value, + }, + { + "polygon": [[0, 0], [2000, 0], [2000, 3000], [0, 3000]], + "score": 0.75, + "text": "first", + "orientation": TextOrientation.HorizontalLeftToRight.value, + }, + { + "polygon": [[1000, 300], [1200, 300], [1200, 500], [1000, 500]], + "score": 0.9, + "text": "line", + "orientation": TextOrientation.HorizontalLeftToRight.value, + }, + ], "return_elements": True, } assert annotations == [ @@ -1077,7 +1387,26 @@ def test_create_element_transcriptions_with_cache( assert json.loads(responses.calls[-1].request.body) == { "element_type": "page", "worker_version": "12341234-1234-1234-1234-123412341234", - "transcriptions": TRANSCRIPTIONS_SAMPLE, + "transcriptions": [ + { + "polygon": [[100, 150], [700, 150], [700, 200], [100, 200]], + "score": 0.5, + "text": "The", + "orientation": TextOrientation.HorizontalLeftToRight.value, + }, + { + "polygon": [[0, 0], [2000, 0], [2000, 3000], [0, 3000]], + "score": 0.75, + "text": "first", + "orientation": TextOrientation.HorizontalLeftToRight.value, + }, + { + "polygon": [[1000, 300], [1200, 300], [1200, 500], [1000, 500]], + "score": 0.9, + "text": "line", + "orientation": TextOrientation.HorizontalLeftToRight.value, + }, + ], "return_elements": True, } assert annotations == [ @@ -1121,6 +1450,7 @@ def test_create_element_transcriptions_with_cache( element_id=UUID("11111111-1111-1111-1111-111111111111"), text="The", confidence=0.5, + orientation=TextOrientation.HorizontalLeftToRight.value, worker_version_id=UUID("12341234-1234-1234-1234-123412341234"), ), CachedTranscription( @@ -1128,6 +1458,7 @@ def test_create_element_transcriptions_with_cache( element_id=UUID("22222222-2222-2222-2222-222222222222"), text="first", confidence=0.75, + orientation=TextOrientation.HorizontalLeftToRight.value, worker_version_id=UUID("12341234-1234-1234-1234-123412341234"), ), CachedTranscription( @@ -1135,11 +1466,168 @@ def test_create_element_transcriptions_with_cache( element_id=UUID("11111111-1111-1111-1111-111111111111"), text="line", confidence=0.9, + orientation=TextOrientation.HorizontalLeftToRight.value, worker_version_id=UUID("12341234-1234-1234-1234-123412341234"), ), ] +def test_create_transcriptions_orientation_with_cache( + responses, mock_elements_worker_with_cache +): + elt = CachedElement(id="12341234-1234-1234-1234-123412341234", type="thing") + + responses.add( + responses.POST, + f"http://testserver/api/v1/element/{elt.id}/transcriptions/bulk/", + status=200, + json=[ + { + "id": "56785678-5678-5678-5678-567856785678", + "element_id": "11111111-1111-1111-1111-111111111111", + "created": True, + }, + { + "id": "67896789-6789-6789-6789-678967896789", + "element_id": "22222222-2222-2222-2222-222222222222", + "created": False, + }, + { + "id": "78907890-7890-7890-7890-789078907890", + "element_id": "11111111-1111-1111-1111-111111111111", + "created": True, + }, + ], + ) + + oriented_transcriptions = [ + { + "polygon": [[100, 150], [700, 150], [700, 200], [100, 200]], + "score": 0.5, + "text": "Animula vagula blandula", + }, + { + "polygon": [[0, 0], [2000, 0], [2000, 3000], [0, 3000]], + "score": 0.75, + "text": "Hospes comesque corporis", + "orientation": TextOrientation.VerticalLeftToRight, + }, + { + "polygon": [[100, 150], [700, 150], [700, 200], [100, 200]], + "score": 0.9, + "text": "Quae nunc abibis in loca", + "orientation": TextOrientation.HorizontalRightToLeft, + }, + ] + + annotations = mock_elements_worker_with_cache.create_element_transcriptions( + element=elt, + sub_element_type="page", + transcriptions=oriented_transcriptions, + ) + + assert json.loads(responses.calls[-1].request.body) == { + "element_type": "page", + "worker_version": "12341234-1234-1234-1234-123412341234", + "transcriptions": [ + { + "polygon": [[100, 150], [700, 150], [700, 200], [100, 200]], + "score": 0.5, + "text": "Animula vagula blandula", + "orientation": TextOrientation.HorizontalLeftToRight.value, + }, + { + "polygon": [[0, 0], [2000, 0], [2000, 3000], [0, 3000]], + "score": 0.75, + "text": "Hospes comesque corporis", + "orientation": TextOrientation.VerticalLeftToRight.value, + }, + { + "polygon": [[100, 150], [700, 150], [700, 200], [100, 200]], + "score": 0.9, + "text": "Quae nunc abibis in loca", + "orientation": TextOrientation.HorizontalRightToLeft.value, + }, + ], + "return_elements": True, + } + assert annotations == [ + { + "id": "56785678-5678-5678-5678-567856785678", + "element_id": "11111111-1111-1111-1111-111111111111", + "created": True, + }, + { + "id": "67896789-6789-6789-6789-678967896789", + "element_id": "22222222-2222-2222-2222-222222222222", + "created": False, + }, + { + "id": "78907890-7890-7890-7890-789078907890", + "element_id": "11111111-1111-1111-1111-111111111111", + "created": True, + }, + ] + + # Check that the text orientation was properly stored in SQLite cache + assert list(map(model_to_dict, CachedTranscription.select())) == [ + { + "id": UUID("56785678-5678-5678-5678-567856785678"), + "element": { + "id": UUID("11111111-1111-1111-1111-111111111111"), + "parent_id": UUID(elt.id), + "type": "page", + "image": None, + "polygon": [[100, 150], [700, 150], [700, 200], [100, 200]], + "rotation_angle": 0, + "mirrored": False, + "initial": False, + "worker_version_id": UUID("12341234-1234-1234-1234-123412341234"), + }, + "text": "Animula vagula blandula", + "confidence": 0.5, + "orientation": TextOrientation.HorizontalLeftToRight.value, + "worker_version_id": UUID("12341234-1234-1234-1234-123412341234"), + }, + { + "id": UUID("67896789-6789-6789-6789-678967896789"), + "element": { + "id": UUID("22222222-2222-2222-2222-222222222222"), + "parent_id": UUID(elt.id), + "type": "page", + "image": None, + "polygon": [[0, 0], [2000, 0], [2000, 3000], [0, 3000]], + "rotation_angle": 0, + "mirrored": False, + "initial": False, + "worker_version_id": UUID("12341234-1234-1234-1234-123412341234"), + }, + "text": "Hospes comesque corporis", + "confidence": 0.75, + "orientation": TextOrientation.VerticalLeftToRight.value, + "worker_version_id": UUID("12341234-1234-1234-1234-123412341234"), + }, + { + "id": UUID("78907890-7890-7890-7890-789078907890"), + "element": { + "id": UUID("11111111-1111-1111-1111-111111111111"), + "parent_id": UUID(elt.id), + "type": "page", + "image": None, + "polygon": [[100, 150], [700, 150], [700, 200], [100, 200]], + "rotation_angle": 0, + "mirrored": False, + "initial": False, + "worker_version_id": UUID("12341234-1234-1234-1234-123412341234"), + }, + "text": "Quae nunc abibis in loca", + "confidence": 0.9, + "orientation": TextOrientation.HorizontalRightToLeft.value, + "worker_version_id": UUID("12341234-1234-1234-1234-123412341234"), + }, + ] + + def test_list_transcriptions_wrong_element(mock_elements_worker): with pytest.raises(AssertionError) as e: mock_elements_worker.list_transcriptions(element=None) -- GitLab