From 661eccd361fac03c5aae2000643e3ad6d468bad5 Mon Sep 17 00:00:00 2001
From: mlbonhomme <bonhomme@teklia.com>
Date: Tue, 9 Nov 2021 13:21:46 +0100
Subject: [PATCH] support text orientation in base worker

---
 arkindex_worker/cache.py                      |  1 +
 arkindex_worker/worker/transcription.py       | 48 +++++++++++++++++--
 tests/conftest.py                             |  6 +++
 tests/test_cache.py                           |  2 +-
 tests/test_elements_worker/test_entities.py   |  2 +
 .../test_transcriptions.py                    | 37 +++++++++++++-
 6 files changed, 90 insertions(+), 6 deletions(-)

diff --git a/arkindex_worker/cache.py b/arkindex_worker/cache.py
index 30141046..b184460d 100644
--- a/arkindex_worker/cache.py
+++ b/arkindex_worker/cache.py
@@ -124,6 +124,7 @@ class CachedTranscription(Model):
     element = ForeignKeyField(CachedElement, backref="transcriptions")
     text = TextField()
     confidence = FloatField()
+    orientation = CharField(max_length=50)
     worker_version_id = UUIDField()
 
     class Meta:
diff --git a/arkindex_worker/worker/transcription.py b/arkindex_worker/worker/transcription.py
index 388f756c..49903703 100644
--- a/arkindex_worker/worker/transcription.py
+++ b/arkindex_worker/worker/transcription.py
@@ -1,5 +1,7 @@
 # -*- coding: utf-8 -*-
 
+from enum import Enum
+
 from peewee import IntegrityError
 
 from arkindex_worker import logger
@@ -7,8 +9,17 @@ from arkindex_worker.cache import CachedElement, CachedTranscription
 from arkindex_worker.models import Element
 
 
+class TextOrientation(Enum):
+    HorizontalLeftToRight = "horizontal-lr"
+    HorizontalRightToLeft = "horizontal-rl"
+    VerticalRightToLeft = "vertical-rl"
+    VerticalLeftToRight = "vertical-lr"
+
+
 class TranscriptionMixin(object):
-    def create_transcription(self, element, text, score):
+    def create_transcription(
+        self, element, text, score, orientation=TextOrientation.HorizontalLeftToRight
+    ):
         """
         Create a transcription on the given element through the API.
         """
@@ -18,7 +29,9 @@ class TranscriptionMixin(object):
         assert text and isinstance(
             text, str
         ), "text shouldn't be null and should be of type str"
-
+        assert orientation and isinstance(
+            orientation, TextOrientation
+        ), "orientation shouldn't be null and should be of type TextOrientation"
         assert (
             isinstance(score, float) and 0 <= score <= 1
         ), "score shouldn't be null and should be a float in [0..1] range"
@@ -36,6 +49,7 @@ class TranscriptionMixin(object):
                 "text": text,
                 "worker_version": self.worker_version_id,
                 "score": score,
+                "orientation": orientation.value,
             },
         )
 
@@ -50,6 +64,7 @@ class TranscriptionMixin(object):
                         "element_id": element.id,
                         "text": created["text"],
                         "confidence": created["confidence"],
+                        "orientation": created["orientation"],
                         "worker_version_id": self.worker_version_id,
                     }
                 ]
@@ -86,6 +101,13 @@ class TranscriptionMixin(object):
                 score is not None and isinstance(score, float) and 0 <= score <= 1
             ), f"Transcription at index {index} in transcriptions: score shouldn't be null and should be a float in [0..1] range"
 
+            orientation = transcription.get(
+                "orientation", TextOrientation.HorizontalLeftToRight
+            )
+            assert orientation and isinstance(
+                orientation, TextOrientation
+            ), f"Transcription at index {index} in transcriptions: orientation shouldn't be null and should be of type TextOrientation"
+
         created_trs = self.request(
             "CreateTranscriptions",
             body={
@@ -106,6 +128,7 @@ class TranscriptionMixin(object):
                         "element_id": created_tr["element_id"],
                         "text": created_tr["text"],
                         "confidence": created_tr["confidence"],
+                        "orientation": created_tr["orientation"],
                         "worker_version_id": self.worker_version_id,
                     }
                     for created_tr in created_trs
@@ -143,6 +166,14 @@ class TranscriptionMixin(object):
                 score is not None and isinstance(score, float) and 0 <= score <= 1
             ), f"Transcription at index {index} in transcriptions: score shouldn't be null and should be a float in [0..1] range"
 
+            orientation = transcription.get(
+                "orientation", TextOrientation.HorizontalLeftToRight
+            )
+            assert orientation and isinstance(
+                orientation, TextOrientation
+            ), f"Transcription at index {index} in transcriptions: orientation shouldn't be null and should be of type TextOrientation"
+            transcription["orientation"] = orientation
+
             polygon = transcription.get("polygon")
             assert polygon and isinstance(
                 polygon, list
@@ -162,13 +193,23 @@ class TranscriptionMixin(object):
             )
             return
 
+        sent_transcriptions = [
+            {
+                "text": transcription["text"],
+                "score": transcription["score"],
+                "orientation": transcription["orientation"].value,
+                "polygon": transcription["polygon"],
+            }
+            for transcription in transcriptions
+        ]
+
         annotations = self.request(
             "CreateElementTranscriptions",
             id=element.id,
             body={
                 "element_type": sub_element_type,
                 "worker_version": self.worker_version_id,
-                "transcriptions": transcriptions,
+                "transcriptions": sent_transcriptions,
                 "return_elements": True,
             },
         )
@@ -216,6 +257,7 @@ class TranscriptionMixin(object):
                         "element_id": annotation["element_id"],
                         "text": transcription["text"],
                         "confidence": transcription["score"],
+                        "orientation": transcription["orientation"],
                         "worker_version_id": self.worker_version_id,
                     }
                 )
diff --git a/tests/conftest.py b/tests/conftest.py
index 5a4596c6..8cdff6f8 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -15,6 +15,7 @@ from arkindex.mock import MockApiClient
 from arkindex_worker.cache import MODELS, CachedElement, CachedTranscription
 from arkindex_worker.git import GitHelper, GitlabHelper
 from arkindex_worker.worker import BaseWorker, ElementsWorker
+from arkindex_worker.worker.transcription import TextOrientation
 
 FIXTURES_DIR = Path(__file__).resolve().parent / "data"
 
@@ -381,6 +382,7 @@ def mock_cached_transcriptions():
         element_id=UUID("11111111-1111-1111-1111-111111111111"),
         text="This",
         confidence=0.42,
+        orientation=TextOrientation.HorizontalLeftToRight,
         worker_version_id=UUID("56785678-5678-5678-5678-567856785678"),
     )
     CachedTranscription.create(
@@ -388,6 +390,7 @@ def mock_cached_transcriptions():
         element_id=UUID("22222222-2222-2222-2222-222222222222"),
         text="is",
         confidence=0.42,
+        orientation=TextOrientation.HorizontalLeftToRight,
         worker_version_id=UUID("90129012-9012-9012-9012-901290129012"),
     )
     CachedTranscription.create(
@@ -395,6 +398,7 @@ def mock_cached_transcriptions():
         element_id=UUID("33333333-3333-3333-3333-333333333333"),
         text="a",
         confidence=0.42,
+        orientation=TextOrientation.HorizontalLeftToRight,
         worker_version_id=UUID("90129012-9012-9012-9012-901290129012"),
     )
     CachedTranscription.create(
@@ -402,6 +406,7 @@ def mock_cached_transcriptions():
         element_id=UUID("44444444-4444-4444-4444-444444444444"),
         text="good",
         confidence=0.42,
+        orientation=TextOrientation.HorizontalLeftToRight,
         worker_version_id=UUID("90129012-9012-9012-9012-901290129012"),
     )
     CachedTranscription.create(
@@ -409,6 +414,7 @@ def mock_cached_transcriptions():
         element_id=UUID("55555555-5555-5555-5555-555555555555"),
         text="test",
         confidence=0.42,
+        orientation=TextOrientation.HorizontalLeftToRight,
         worker_version_id=UUID("90129012-9012-9012-9012-901290129012"),
     )
 
diff --git a/tests/test_cache.py b/tests/test_cache.py
index 121fa580..03a68144 100644
--- a/tests/test_cache.py
+++ b/tests/test_cache.py
@@ -59,7 +59,7 @@ CREATE TABLE "elements" ("id" TEXT NOT NULL PRIMARY KEY, "parent_id" TEXT, "type
 CREATE TABLE "entities" ("id" TEXT NOT NULL PRIMARY KEY, "type" VARCHAR(50) NOT NULL, "name" TEXT NOT NULL, "validated" INTEGER NOT NULL, "metas" text, "worker_version_id" TEXT NOT NULL)
 CREATE TABLE "images" ("id" TEXT NOT NULL PRIMARY KEY, "width" INTEGER NOT NULL, "height" INTEGER NOT NULL, "url" TEXT NOT NULL)
 CREATE TABLE "transcription_entities" ("transcription_id" TEXT NOT NULL, "entity_id" TEXT NOT NULL, "offset" INTEGER NOT NULL CHECK (offset >= 0), "length" INTEGER NOT NULL CHECK (length > 0), "worker_version_id" TEXT NOT NULL, PRIMARY KEY ("transcription_id", "entity_id"), FOREIGN KEY ("transcription_id") REFERENCES "transcriptions" ("id"), FOREIGN KEY ("entity_id") REFERENCES "entities" ("id"))
-CREATE TABLE "transcriptions" ("id" TEXT NOT NULL PRIMARY KEY, "element_id" TEXT NOT NULL, "text" TEXT NOT NULL, "confidence" REAL NOT NULL, "worker_version_id" TEXT NOT NULL, FOREIGN KEY ("element_id") REFERENCES "elements" ("id"))"""
+CREATE TABLE "transcriptions" ("id" TEXT NOT NULL PRIMARY KEY, "element_id" TEXT NOT NULL, "text" TEXT NOT NULL, "confidence" REAL NOT NULL, "orientation" VARCHAR(50) NOT NULL, "worker_version_id" TEXT NOT NULL, FOREIGN KEY ("element_id") REFERENCES "elements" ("id"))"""
 
     actual_schema = "\n".join(
         [
diff --git a/tests/test_elements_worker/test_entities.py b/tests/test_elements_worker/test_entities.py
index 8b16ba15..e4f43b6f 100644
--- a/tests/test_elements_worker/test_entities.py
+++ b/tests/test_elements_worker/test_entities.py
@@ -13,6 +13,7 @@ from arkindex_worker.cache import (
 )
 from arkindex_worker.models import Element
 from arkindex_worker.worker import EntityType
+from arkindex_worker.worker.transcription import TextOrientation
 
 from . import BASE_API_CALLS
 
@@ -465,6 +466,7 @@ def test_create_transcription_entity_with_cache(
         element=UUID("12341234-1234-1234-1234-123412341234"),
         text="Hello, it's me.",
         confidence=0.42,
+        orientation=TextOrientation.HorizontalLeftToRight,
         worker_version_id=UUID("12341234-1234-1234-1234-123412341234"),
     )
     CachedEntity.create(
diff --git a/tests/test_elements_worker/test_transcriptions.py b/tests/test_elements_worker/test_transcriptions.py
index d7166b54..1b59548d 100644
--- a/tests/test_elements_worker/test_transcriptions.py
+++ b/tests/test_elements_worker/test_transcriptions.py
@@ -7,6 +7,7 @@ from apistar.exceptions import ErrorResponse
 
 from arkindex_worker.cache import CachedElement, CachedTranscription
 from arkindex_worker.models import Element
+from arkindex_worker.worker.transcription import TextOrientation
 
 from . import BASE_API_CALLS
 
@@ -177,6 +178,7 @@ def test_create_transcription(responses, mock_elements_worker):
         "text": "i am a line",
         "worker_version": "12341234-1234-1234-1234-123412341234",
         "score": 0.42,
+        "orientation": "horizontal-lr",
     }
 
 
@@ -192,6 +194,7 @@ def test_create_transcription_with_cache(responses, mock_elements_worker_with_ca
             "text": "i am a line",
             "score": 0.42,
             "confidence": 0.42,
+            "orientation": "horizontal-lr",
             "worker_version_id": "12341234-1234-1234-1234-123412341234",
         },
     )
@@ -200,6 +203,7 @@ def test_create_transcription_with_cache(responses, mock_elements_worker_with_ca
         element=elt,
         text="i am a line",
         score=0.42,
+        orientation=TextOrientation.HorizontalLeftToRight,
     )
 
     assert len(responses.calls) == len(BASE_API_CALLS) + 1
@@ -212,6 +216,7 @@ def test_create_transcription_with_cache(responses, mock_elements_worker_with_ca
     assert json.loads(responses.calls[-1].request.body) == {
         "text": "i am a line",
         "worker_version": "12341234-1234-1234-1234-123412341234",
+        "orientation": "horizontal-lr",
         "score": 0.42,
     }
 
@@ -222,6 +227,7 @@ def test_create_transcription_with_cache(responses, mock_elements_worker_with_ca
             element_id=UUID(elt.id),
             text="i am a line",
             confidence=0.42,
+            orientation=TextOrientation.HorizontalLeftToRight,
             worker_version_id=UUID("12341234-1234-1234-1234-123412341234"),
         )
     ]
@@ -519,12 +525,14 @@ def test_create_transcriptions(responses, mock_elements_worker_with_cache):
                     "id": "00000000-0000-0000-0000-000000000000",
                     "element_id": "11111111-1111-1111-1111-111111111111",
                     "text": "The",
+                    "orientation": "horizontal-lr",
                     "confidence": 0.75,
                 },
                 {
                     "id": "11111111-1111-1111-1111-111111111111",
                     "element_id": "11111111-1111-1111-1111-111111111111",
                     "text": "word",
+                    "orientation": "horizontal-lr",
                     "confidence": 0.42,
                 },
             ],
@@ -554,6 +562,7 @@ def test_create_transcriptions(responses, mock_elements_worker_with_cache):
             element_id=UUID("11111111-1111-1111-1111-111111111111"),
             text="The",
             confidence=0.75,
+            orientation=TextOrientation.HorizontalLeftToRight,
             worker_version_id=UUID("12341234-1234-1234-1234-123412341234"),
         ),
         CachedTranscription(
@@ -561,6 +570,7 @@ def test_create_transcriptions(responses, mock_elements_worker_with_cache):
             element_id=UUID("11111111-1111-1111-1111-111111111111"),
             text="word",
             confidence=0.42,
+            orientation=TextOrientation.HorizontalLeftToRight,
             worker_version_id=UUID("12341234-1234-1234-1234-123412341234"),
         ),
     ]
@@ -1008,10 +1018,20 @@ def test_create_element_transcriptions(responses, mock_elements_worker):
         ("POST", f"http://testserver/api/v1/element/{elt.id}/transcriptions/bulk/"),
     ]
 
+    sent_transcriptions_sample = [
+        {
+            "text": tr["text"],
+            "score": tr["score"],
+            "orientation": tr["orientation"].value,
+            "polygon": tr["polygon"]
+        }
+        for tr in TRANSCRIPTIONS_SAMPLE
+    ]
+
     assert json.loads(responses.calls[-1].request.body) == {
         "element_type": "page",
         "worker_version": "12341234-1234-1234-1234-123412341234",
-        "transcriptions": TRANSCRIPTIONS_SAMPLE,
+        "transcriptions": sent_transcriptions_sample,
         "return_elements": True,
     }
     assert annotations == [
@@ -1074,10 +1094,20 @@ def test_create_element_transcriptions_with_cache(
         ("POST", f"http://testserver/api/v1/element/{elt.id}/transcriptions/bulk/"),
     ]
 
+    sent_transcriptions_sample = [
+        {
+            "text": tr["text"],
+            "score": tr["score"],
+            "orientation": tr["orientation"].value,
+            "polygon": tr["polygon"]
+        }
+        for tr in TRANSCRIPTIONS_SAMPLE
+    ]
+
     assert json.loads(responses.calls[-1].request.body) == {
         "element_type": "page",
         "worker_version": "12341234-1234-1234-1234-123412341234",
-        "transcriptions": TRANSCRIPTIONS_SAMPLE,
+        "transcriptions": sent_transcriptions_sample,
         "return_elements": True,
     }
     assert annotations == [
@@ -1121,6 +1151,7 @@ def test_create_element_transcriptions_with_cache(
             element_id=UUID("11111111-1111-1111-1111-111111111111"),
             text="The",
             confidence=0.5,
+            orientation=TextOrientation.HorizontalLeftToRight,
             worker_version_id=UUID("12341234-1234-1234-1234-123412341234"),
         ),
         CachedTranscription(
@@ -1128,6 +1159,7 @@ def test_create_element_transcriptions_with_cache(
             element_id=UUID("22222222-2222-2222-2222-222222222222"),
             text="first",
             confidence=0.75,
+            orientation=TextOrientation.HorizontalLeftToRight,
             worker_version_id=UUID("12341234-1234-1234-1234-123412341234"),
         ),
         CachedTranscription(
@@ -1135,6 +1167,7 @@ def test_create_element_transcriptions_with_cache(
             element_id=UUID("11111111-1111-1111-1111-111111111111"),
             text="line",
             confidence=0.9,
+            orientation=TextOrientation.HorizontalLeftToRight,
             worker_version_id=UUID("12341234-1234-1234-1234-123412341234"),
         ),
     ]
-- 
GitLab