From 661eccd361fac03c5aae2000643e3ad6d468bad5 Mon Sep 17 00:00:00 2001 From: mlbonhomme <bonhomme@teklia.com> Date: Tue, 9 Nov 2021 13:21:46 +0100 Subject: [PATCH] support text orientation in base worker --- arkindex_worker/cache.py | 1 + arkindex_worker/worker/transcription.py | 48 +++++++++++++++++-- tests/conftest.py | 6 +++ tests/test_cache.py | 2 +- tests/test_elements_worker/test_entities.py | 2 + .../test_transcriptions.py | 37 +++++++++++++- 6 files changed, 90 insertions(+), 6 deletions(-) diff --git a/arkindex_worker/cache.py b/arkindex_worker/cache.py index 30141046..b184460d 100644 --- a/arkindex_worker/cache.py +++ b/arkindex_worker/cache.py @@ -124,6 +124,7 @@ class CachedTranscription(Model): element = ForeignKeyField(CachedElement, backref="transcriptions") text = TextField() confidence = FloatField() + orientation = CharField(max_length=50) worker_version_id = UUIDField() class Meta: diff --git a/arkindex_worker/worker/transcription.py b/arkindex_worker/worker/transcription.py index 388f756c..49903703 100644 --- a/arkindex_worker/worker/transcription.py +++ b/arkindex_worker/worker/transcription.py @@ -1,5 +1,7 @@ # -*- coding: utf-8 -*- +from enum import Enum + from peewee import IntegrityError from arkindex_worker import logger @@ -7,8 +9,17 @@ from arkindex_worker.cache import CachedElement, CachedTranscription from arkindex_worker.models import Element +class TextOrientation(Enum): + HorizontalLeftToRight = "horizontal-lr" + HorizontalRightToLeft = "horizontal-rl" + VerticalRightToLeft = "vertical-rl" + VerticalLeftToRight = "vertical-lr" + + class TranscriptionMixin(object): - def create_transcription(self, element, text, score): + def create_transcription( + self, element, text, score, orientation=TextOrientation.HorizontalLeftToRight + ): """ Create a transcription on the given element through the API. """ @@ -18,7 +29,9 @@ class TranscriptionMixin(object): assert text and isinstance( text, str ), "text shouldn't be null and should be of type str" - + assert orientation and isinstance( + orientation, TextOrientation + ), "orientation shouldn't be null and should be of type TextOrientation" assert ( isinstance(score, float) and 0 <= score <= 1 ), "score shouldn't be null and should be a float in [0..1] range" @@ -36,6 +49,7 @@ class TranscriptionMixin(object): "text": text, "worker_version": self.worker_version_id, "score": score, + "orientation": orientation.value, }, ) @@ -50,6 +64,7 @@ class TranscriptionMixin(object): "element_id": element.id, "text": created["text"], "confidence": created["confidence"], + "orientation": created["orientation"], "worker_version_id": self.worker_version_id, } ] @@ -86,6 +101,13 @@ class TranscriptionMixin(object): score is not None and isinstance(score, float) and 0 <= score <= 1 ), f"Transcription at index {index} in transcriptions: score shouldn't be null and should be a float in [0..1] range" + orientation = transcription.get( + "orientation", TextOrientation.HorizontalLeftToRight + ) + assert orientation and isinstance( + orientation, TextOrientation + ), f"Transcription at index {index} in transcriptions: orientation shouldn't be null and should be of type TextOrientation" + created_trs = self.request( "CreateTranscriptions", body={ @@ -106,6 +128,7 @@ class TranscriptionMixin(object): "element_id": created_tr["element_id"], "text": created_tr["text"], "confidence": created_tr["confidence"], + "orientation": created_tr["orientation"], "worker_version_id": self.worker_version_id, } for created_tr in created_trs @@ -143,6 +166,14 @@ class TranscriptionMixin(object): score is not None and isinstance(score, float) and 0 <= score <= 1 ), f"Transcription at index {index} in transcriptions: score shouldn't be null and should be a float in [0..1] range" + orientation = transcription.get( + "orientation", TextOrientation.HorizontalLeftToRight + ) + assert orientation and isinstance( + orientation, TextOrientation + ), f"Transcription at index {index} in transcriptions: orientation shouldn't be null and should be of type TextOrientation" + transcription["orientation"] = orientation + polygon = transcription.get("polygon") assert polygon and isinstance( polygon, list @@ -162,13 +193,23 @@ class TranscriptionMixin(object): ) return + sent_transcriptions = [ + { + "text": transcription["text"], + "score": transcription["score"], + "orientation": transcription["orientation"].value, + "polygon": transcription["polygon"], + } + for transcription in transcriptions + ] + annotations = self.request( "CreateElementTranscriptions", id=element.id, body={ "element_type": sub_element_type, "worker_version": self.worker_version_id, - "transcriptions": transcriptions, + "transcriptions": sent_transcriptions, "return_elements": True, }, ) @@ -216,6 +257,7 @@ class TranscriptionMixin(object): "element_id": annotation["element_id"], "text": transcription["text"], "confidence": transcription["score"], + "orientation": transcription["orientation"], "worker_version_id": self.worker_version_id, } ) diff --git a/tests/conftest.py b/tests/conftest.py index 5a4596c6..8cdff6f8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -15,6 +15,7 @@ from arkindex.mock import MockApiClient from arkindex_worker.cache import MODELS, CachedElement, CachedTranscription from arkindex_worker.git import GitHelper, GitlabHelper from arkindex_worker.worker import BaseWorker, ElementsWorker +from arkindex_worker.worker.transcription import TextOrientation FIXTURES_DIR = Path(__file__).resolve().parent / "data" @@ -381,6 +382,7 @@ def mock_cached_transcriptions(): element_id=UUID("11111111-1111-1111-1111-111111111111"), text="This", confidence=0.42, + orientation=TextOrientation.HorizontalLeftToRight, worker_version_id=UUID("56785678-5678-5678-5678-567856785678"), ) CachedTranscription.create( @@ -388,6 +390,7 @@ def mock_cached_transcriptions(): element_id=UUID("22222222-2222-2222-2222-222222222222"), text="is", confidence=0.42, + orientation=TextOrientation.HorizontalLeftToRight, worker_version_id=UUID("90129012-9012-9012-9012-901290129012"), ) CachedTranscription.create( @@ -395,6 +398,7 @@ def mock_cached_transcriptions(): element_id=UUID("33333333-3333-3333-3333-333333333333"), text="a", confidence=0.42, + orientation=TextOrientation.HorizontalLeftToRight, worker_version_id=UUID("90129012-9012-9012-9012-901290129012"), ) CachedTranscription.create( @@ -402,6 +406,7 @@ def mock_cached_transcriptions(): element_id=UUID("44444444-4444-4444-4444-444444444444"), text="good", confidence=0.42, + orientation=TextOrientation.HorizontalLeftToRight, worker_version_id=UUID("90129012-9012-9012-9012-901290129012"), ) CachedTranscription.create( @@ -409,6 +414,7 @@ def mock_cached_transcriptions(): element_id=UUID("55555555-5555-5555-5555-555555555555"), text="test", confidence=0.42, + orientation=TextOrientation.HorizontalLeftToRight, worker_version_id=UUID("90129012-9012-9012-9012-901290129012"), ) diff --git a/tests/test_cache.py b/tests/test_cache.py index 121fa580..03a68144 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -59,7 +59,7 @@ CREATE TABLE "elements" ("id" TEXT NOT NULL PRIMARY KEY, "parent_id" TEXT, "type CREATE TABLE "entities" ("id" TEXT NOT NULL PRIMARY KEY, "type" VARCHAR(50) NOT NULL, "name" TEXT NOT NULL, "validated" INTEGER NOT NULL, "metas" text, "worker_version_id" TEXT NOT NULL) CREATE TABLE "images" ("id" TEXT NOT NULL PRIMARY KEY, "width" INTEGER NOT NULL, "height" INTEGER NOT NULL, "url" TEXT NOT NULL) CREATE TABLE "transcription_entities" ("transcription_id" TEXT NOT NULL, "entity_id" TEXT NOT NULL, "offset" INTEGER NOT NULL CHECK (offset >= 0), "length" INTEGER NOT NULL CHECK (length > 0), "worker_version_id" TEXT NOT NULL, PRIMARY KEY ("transcription_id", "entity_id"), FOREIGN KEY ("transcription_id") REFERENCES "transcriptions" ("id"), FOREIGN KEY ("entity_id") REFERENCES "entities" ("id")) -CREATE TABLE "transcriptions" ("id" TEXT NOT NULL PRIMARY KEY, "element_id" TEXT NOT NULL, "text" TEXT NOT NULL, "confidence" REAL NOT NULL, "worker_version_id" TEXT NOT NULL, FOREIGN KEY ("element_id") REFERENCES "elements" ("id"))""" +CREATE TABLE "transcriptions" ("id" TEXT NOT NULL PRIMARY KEY, "element_id" TEXT NOT NULL, "text" TEXT NOT NULL, "confidence" REAL NOT NULL, "orientation" VARCHAR(50) NOT NULL, "worker_version_id" TEXT NOT NULL, FOREIGN KEY ("element_id") REFERENCES "elements" ("id"))""" actual_schema = "\n".join( [ diff --git a/tests/test_elements_worker/test_entities.py b/tests/test_elements_worker/test_entities.py index 8b16ba15..e4f43b6f 100644 --- a/tests/test_elements_worker/test_entities.py +++ b/tests/test_elements_worker/test_entities.py @@ -13,6 +13,7 @@ from arkindex_worker.cache import ( ) from arkindex_worker.models import Element from arkindex_worker.worker import EntityType +from arkindex_worker.worker.transcription import TextOrientation from . import BASE_API_CALLS @@ -465,6 +466,7 @@ def test_create_transcription_entity_with_cache( element=UUID("12341234-1234-1234-1234-123412341234"), text="Hello, it's me.", confidence=0.42, + orientation=TextOrientation.HorizontalLeftToRight, worker_version_id=UUID("12341234-1234-1234-1234-123412341234"), ) CachedEntity.create( diff --git a/tests/test_elements_worker/test_transcriptions.py b/tests/test_elements_worker/test_transcriptions.py index d7166b54..1b59548d 100644 --- a/tests/test_elements_worker/test_transcriptions.py +++ b/tests/test_elements_worker/test_transcriptions.py @@ -7,6 +7,7 @@ from apistar.exceptions import ErrorResponse from arkindex_worker.cache import CachedElement, CachedTranscription from arkindex_worker.models import Element +from arkindex_worker.worker.transcription import TextOrientation from . import BASE_API_CALLS @@ -177,6 +178,7 @@ def test_create_transcription(responses, mock_elements_worker): "text": "i am a line", "worker_version": "12341234-1234-1234-1234-123412341234", "score": 0.42, + "orientation": "horizontal-lr", } @@ -192,6 +194,7 @@ def test_create_transcription_with_cache(responses, mock_elements_worker_with_ca "text": "i am a line", "score": 0.42, "confidence": 0.42, + "orientation": "horizontal-lr", "worker_version_id": "12341234-1234-1234-1234-123412341234", }, ) @@ -200,6 +203,7 @@ def test_create_transcription_with_cache(responses, mock_elements_worker_with_ca element=elt, text="i am a line", score=0.42, + orientation=TextOrientation.HorizontalLeftToRight, ) assert len(responses.calls) == len(BASE_API_CALLS) + 1 @@ -212,6 +216,7 @@ def test_create_transcription_with_cache(responses, mock_elements_worker_with_ca assert json.loads(responses.calls[-1].request.body) == { "text": "i am a line", "worker_version": "12341234-1234-1234-1234-123412341234", + "orientation": "horizontal-lr", "score": 0.42, } @@ -222,6 +227,7 @@ def test_create_transcription_with_cache(responses, mock_elements_worker_with_ca element_id=UUID(elt.id), text="i am a line", confidence=0.42, + orientation=TextOrientation.HorizontalLeftToRight, worker_version_id=UUID("12341234-1234-1234-1234-123412341234"), ) ] @@ -519,12 +525,14 @@ def test_create_transcriptions(responses, mock_elements_worker_with_cache): "id": "00000000-0000-0000-0000-000000000000", "element_id": "11111111-1111-1111-1111-111111111111", "text": "The", + "orientation": "horizontal-lr", "confidence": 0.75, }, { "id": "11111111-1111-1111-1111-111111111111", "element_id": "11111111-1111-1111-1111-111111111111", "text": "word", + "orientation": "horizontal-lr", "confidence": 0.42, }, ], @@ -554,6 +562,7 @@ def test_create_transcriptions(responses, mock_elements_worker_with_cache): element_id=UUID("11111111-1111-1111-1111-111111111111"), text="The", confidence=0.75, + orientation=TextOrientation.HorizontalLeftToRight, worker_version_id=UUID("12341234-1234-1234-1234-123412341234"), ), CachedTranscription( @@ -561,6 +570,7 @@ def test_create_transcriptions(responses, mock_elements_worker_with_cache): element_id=UUID("11111111-1111-1111-1111-111111111111"), text="word", confidence=0.42, + orientation=TextOrientation.HorizontalLeftToRight, worker_version_id=UUID("12341234-1234-1234-1234-123412341234"), ), ] @@ -1008,10 +1018,20 @@ def test_create_element_transcriptions(responses, mock_elements_worker): ("POST", f"http://testserver/api/v1/element/{elt.id}/transcriptions/bulk/"), ] + sent_transcriptions_sample = [ + { + "text": tr["text"], + "score": tr["score"], + "orientation": tr["orientation"].value, + "polygon": tr["polygon"] + } + for tr in TRANSCRIPTIONS_SAMPLE + ] + assert json.loads(responses.calls[-1].request.body) == { "element_type": "page", "worker_version": "12341234-1234-1234-1234-123412341234", - "transcriptions": TRANSCRIPTIONS_SAMPLE, + "transcriptions": sent_transcriptions_sample, "return_elements": True, } assert annotations == [ @@ -1074,10 +1094,20 @@ def test_create_element_transcriptions_with_cache( ("POST", f"http://testserver/api/v1/element/{elt.id}/transcriptions/bulk/"), ] + sent_transcriptions_sample = [ + { + "text": tr["text"], + "score": tr["score"], + "orientation": tr["orientation"].value, + "polygon": tr["polygon"] + } + for tr in TRANSCRIPTIONS_SAMPLE + ] + assert json.loads(responses.calls[-1].request.body) == { "element_type": "page", "worker_version": "12341234-1234-1234-1234-123412341234", - "transcriptions": TRANSCRIPTIONS_SAMPLE, + "transcriptions": sent_transcriptions_sample, "return_elements": True, } assert annotations == [ @@ -1121,6 +1151,7 @@ def test_create_element_transcriptions_with_cache( element_id=UUID("11111111-1111-1111-1111-111111111111"), text="The", confidence=0.5, + orientation=TextOrientation.HorizontalLeftToRight, worker_version_id=UUID("12341234-1234-1234-1234-123412341234"), ), CachedTranscription( @@ -1128,6 +1159,7 @@ def test_create_element_transcriptions_with_cache( element_id=UUID("22222222-2222-2222-2222-222222222222"), text="first", confidence=0.75, + orientation=TextOrientation.HorizontalLeftToRight, worker_version_id=UUID("12341234-1234-1234-1234-123412341234"), ), CachedTranscription( @@ -1135,6 +1167,7 @@ def test_create_element_transcriptions_with_cache( element_id=UUID("11111111-1111-1111-1111-111111111111"), text="line", confidence=0.9, + orientation=TextOrientation.HorizontalLeftToRight, worker_version_id=UUID("12341234-1234-1234-1234-123412341234"), ), ] -- GitLab