From 827abba105bd7eddf05fafbb88d97f2749911a0c Mon Sep 17 00:00:00 2001
From: mlbonhomme <bonhomme@teklia.com>
Date: Tue, 9 Nov 2021 16:51:16 +0100
Subject: [PATCH] actually check the text orientation in the cache

---
 arkindex_worker/worker/transcription.py       |  14 +-
 tests/conftest.py                             |   2 +
 .../test_transcriptions.py                    | 384 +++++++++++++++++-
 3 files changed, 396 insertions(+), 4 deletions(-)

diff --git a/arkindex_worker/worker/transcription.py b/arkindex_worker/worker/transcription.py
index 49903703..c415beb1 100644
--- a/arkindex_worker/worker/transcription.py
+++ b/arkindex_worker/worker/transcription.py
@@ -107,12 +107,22 @@ class TranscriptionMixin(object):
             assert orientation and isinstance(
                 orientation, TextOrientation
             ), f"Transcription at index {index} in transcriptions: orientation shouldn't be null and should be of type TextOrientation"
+            transcription["orientation"] = orientation
+
+        sent_transcriptions = [
+            {
+                "text": transcription["text"],
+                "score": transcription["score"],
+                "orientation": transcription["orientation"].value,
+            }
+            for transcription in transcriptions
+        ]
 
         created_trs = self.request(
             "CreateTranscriptions",
             body={
                 "worker_version": self.worker_version_id,
-                "transcriptions": transcriptions,
+                "transcriptions": sent_transcriptions,
             },
         )["transcriptions"]
 
@@ -257,7 +267,7 @@ class TranscriptionMixin(object):
                         "element_id": annotation["element_id"],
                         "text": transcription["text"],
                         "confidence": transcription["score"],
-                        "orientation": transcription["orientation"],
+                        "orientation": transcription["orientation"].value,
                         "worker_version_id": self.worker_version_id,
                     }
                 )
diff --git a/tests/conftest.py b/tests/conftest.py
index 8cdff6f8..aa7e521f 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -467,6 +467,7 @@ def mock_databases(tmpdir):
             element_id=UUID("42424242-4242-4242-4242-424242424242"),
             text="Hello!",
             confidence=0.42,
+            orientation=TextOrientation.HorizontalLeftToRight,
             worker_version_id=UUID("56785678-5678-5678-5678-567856785678"),
         )
 
@@ -483,6 +484,7 @@ def mock_databases(tmpdir):
             element_id=UUID("42424242-4242-4242-4242-424242424242"),
             text="Hello again neighbor !",
             confidence=0.42,
+            orientation=TextOrientation.HorizontalLeftToRight,
             worker_version_id=UUID("56785678-5678-5678-5678-567856785678"),
         )
 
diff --git a/tests/test_elements_worker/test_transcriptions.py b/tests/test_elements_worker/test_transcriptions.py
index 7277b917..39e0420f 100644
--- a/tests/test_elements_worker/test_transcriptions.py
+++ b/tests/test_elements_worker/test_transcriptions.py
@@ -118,6 +118,76 @@ def test_create_transcription_wrong_score(mock_elements_worker):
     )
 
 
+def test_create_transcription_default_orientation(responses, mock_elements_worker):
+    elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
+    responses.add(
+        responses.POST,
+        f"http://testserver/api/v1/element/{elt.id}/transcription/",
+        status=200,
+        json={
+            "id": "56785678-5678-5678-5678-567856785678",
+            "text": "Animula vagula blandula",
+            "score": 0.42,
+            "confidence": 0.42,
+            "worker_version_id": "12341234-1234-1234-1234-123412341234",
+        },
+    )
+    mock_elements_worker.create_transcription(
+        element=elt,
+        text="Animula vagula blandula",
+        score=0.42,
+    )
+    assert json.loads(responses.calls[-1].request.body) == {
+        "text": "Animula vagula blandula",
+        "worker_version": "12341234-1234-1234-1234-123412341234",
+        "score": 0.42,
+        "orientation": "horizontal-lr",
+    }
+
+
+def test_create_transcription_orientation(responses, mock_elements_worker):
+    elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
+    responses.add(
+        responses.POST,
+        f"http://testserver/api/v1/element/{elt.id}/transcription/",
+        status=200,
+        json={
+            "id": "56785678-5678-5678-5678-567856785678",
+            "text": "Animula vagula blandula",
+            "score": 0.42,
+            "confidence": 0.42,
+            "worker_version_id": "12341234-1234-1234-1234-123412341234",
+        },
+    )
+    mock_elements_worker.create_transcription(
+        element=elt,
+        text="Animula vagula blandula",
+        orientation=TextOrientation.VerticalLeftToRight,
+        score=0.42,
+    )
+    assert json.loads(responses.calls[-1].request.body) == {
+        "text": "Animula vagula blandula",
+        "worker_version": "12341234-1234-1234-1234-123412341234",
+        "score": 0.42,
+        "orientation": "vertical-lr",
+    }
+
+
+def test_create_transcription_wrong_orientation(mock_elements_worker):
+    elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
+    with pytest.raises(AssertionError) as e:
+        mock_elements_worker.create_transcription(
+            element=elt,
+            text="Animula vagula blandula",
+            score=0.26,
+            orientation="eliptical",
+        )
+    assert (
+        str(e.value)
+        == "orientation shouldn't be null and should be of type TextOrientation"
+    )
+
+
 def test_create_transcription_api_error(responses, mock_elements_worker):
     elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
     responses.add(
@@ -203,7 +273,6 @@ def test_create_transcription_with_cache(responses, mock_elements_worker_with_ca
         element=elt,
         text="i am a line",
         score=0.42,
-        orientation=TextOrientation.HorizontalLeftToRight,
     )
 
     assert len(responses.calls) == len(BASE_API_CALLS) + 1
@@ -233,6 +302,66 @@ def test_create_transcription_with_cache(responses, mock_elements_worker_with_ca
     ]
 
 
+def test_create_transcription_orientation_with_cache(
+    responses, mock_elements_worker_with_cache
+):
+    elt = CachedElement.create(id="12341234-1234-1234-1234-123412341234", type="thing")
+    responses.add(
+        responses.POST,
+        f"http://testserver/api/v1/element/{elt.id}/transcription/",
+        status=200,
+        json={
+            "id": "56785678-5678-5678-5678-567856785678",
+            "text": "Animula vagula blandula",
+            "score": 0.42,
+            "confidence": 0.42,
+            "orientation": "vertical-lr",
+            "worker_version_id": "12341234-1234-1234-1234-123412341234",
+        },
+    )
+    mock_elements_worker_with_cache.create_transcription(
+        element=elt,
+        text="Animula vagula blandula",
+        orientation=TextOrientation.VerticalLeftToRight,
+        score=0.42,
+    )
+    assert json.loads(responses.calls[-1].request.body) == {
+        "text": "Animula vagula blandula",
+        "worker_version": "12341234-1234-1234-1234-123412341234",
+        "orientation": "vertical-lr",
+        "score": 0.42,
+    }
+    # Check that the text orientation was properly stored in SQLite cache
+    assert list(CachedTranscription.select()) == [
+        CachedTranscription(
+            id=UUID("56785678-5678-5678-5678-567856785678"),
+        )
+    ]
+    assert [
+        {
+            k: getattr(transcription, k)
+            for k in [
+                "id",
+                "element_id",
+                "text",
+                "confidence",
+                "orientation",
+                "worker_version_id",
+            ]
+        }
+        for transcription in CachedTranscription.select()
+    ] == [
+        {
+            "id": UUID("56785678-5678-5678-5678-567856785678"),
+            "element_id": UUID("12341234-1234-1234-1234-123412341234"),
+            "text": "Animula vagula blandula",
+            "confidence": 0.42,
+            "orientation": TextOrientation.VerticalLeftToRight.value,
+            "worker_version_id": UUID("12341234-1234-1234-1234-123412341234"),
+        }
+    ]
+
+
 def test_create_transcriptions_wrong_transcriptions(mock_elements_worker):
     with pytest.raises(AssertionError) as e:
         mock_elements_worker.create_transcriptions(
@@ -463,6 +592,27 @@ def test_create_transcriptions_wrong_transcriptions(mock_elements_worker):
         == "Transcription at index 1 in transcriptions: score shouldn't be null and should be a float in [0..1] range"
     )
 
+    with pytest.raises(AssertionError) as e:
+        mock_elements_worker.create_transcriptions(
+            transcriptions=[
+                {
+                    "element_id": "11111111-1111-1111-1111-111111111111",
+                    "text": "The",
+                    "score": 0.75,
+                },
+                {
+                    "element_id": "11111111-1111-1111-1111-111111111111",
+                    "text": "word",
+                    "score": 0.28,
+                    "orientation": "wobble",
+                },
+            ],
+        )
+    assert (
+        str(e.value)
+        == "Transcription at index 1 in transcriptions: orientation shouldn't be null and should be of type TextOrientation"
+    )
+
 
 def test_create_transcriptions_api_error(responses, mock_elements_worker):
     responses.add(
@@ -552,7 +702,14 @@ def test_create_transcriptions(responses, mock_elements_worker_with_cache):
 
     assert json.loads(responses.calls[-1].request.body) == {
         "worker_version": "12341234-1234-1234-1234-123412341234",
-        "transcriptions": trans,
+        "transcriptions": [
+            {
+                "text": tr["text"],
+                "score": tr["score"],
+                "orientation": tr["orientation"].value,
+            }
+            for tr in trans
+        ],
     }
 
     # Check that created transcriptions were properly stored in SQLite cache
@@ -576,6 +733,65 @@ def test_create_transcriptions(responses, mock_elements_worker_with_cache):
     ]
 
 
+def test_create_transcriptions_orientation(responses, mock_elements_worker_with_cache):
+    CachedElement.create(id="11111111-1111-1111-1111-111111111111", type="thing")
+    trans = [
+        {
+            "element_id": "11111111-1111-1111-1111-111111111111",
+            "text": "Animula vagula blandula",
+            "score": 0.12,
+            "orientation": TextOrientation.HorizontalRightToLeft,
+        },
+        {
+            "element_id": "11111111-1111-1111-1111-111111111111",
+            "text": "Hospes comesque corporis",
+            "score": 0.21,
+            "orientation": TextOrientation.VerticalLeftToRight,
+        },
+    ]
+
+    responses.add(
+        responses.POST,
+        "http://testserver/api/v1/transcription/bulk/",
+        status=200,
+        json={
+            "worker_version": "12341234-1234-1234-1234-123412341234",
+            "transcriptions": [
+                {
+                    "id": "00000000-0000-0000-0000-000000000000",
+                    "element_id": "11111111-1111-1111-1111-111111111111",
+                    "text": "Animula vagula blandula",
+                    "orientation": "horizontal-rl",
+                    "confidence": 0.12,
+                },
+                {
+                    "id": "11111111-1111-1111-1111-111111111111",
+                    "element_id": "11111111-1111-1111-1111-111111111111",
+                    "text": "Hospes comesque corporis",
+                    "orientation": "vertical-lr",
+                    "confidence": 0.21,
+                },
+            ],
+        },
+    )
+
+    mock_elements_worker_with_cache.create_transcriptions(
+        transcriptions=trans,
+    )
+
+    assert json.loads(responses.calls[-1].request.body) == {
+        "worker_version": "12341234-1234-1234-1234-123412341234",
+        "transcriptions": [
+            {
+                "text": tr["text"],
+                "score": tr["score"],
+                "orientation": tr["orientation"].value,
+            }
+            for tr in trans
+        ],
+    }
+
+
 def test_create_element_transcriptions_wrong_element(mock_elements_worker):
     with pytest.raises(AssertionError) as e:
         mock_elements_worker.create_element_transcriptions(
@@ -951,6 +1167,29 @@ def test_create_element_transcriptions_wrong_transcriptions(mock_elements_worker
         == "Transcription at index 1 in transcriptions: polygon points should be lists of two numbers"
     )
 
+    with pytest.raises(AssertionError) as e:
+        mock_elements_worker.create_element_transcriptions(
+            element=elt,
+            sub_element_type="page",
+            transcriptions=[
+                {
+                    "polygon": [[0, 0], [2000, 0], [2000, 3000], [0, 3000]],
+                    "score": 0.75,
+                    "text": "The",
+                },
+                {
+                    "polygon": [[100, 150], [700, 150], [700, 200], [100, 200]],
+                    "score": 0.35,
+                    "text": "word",
+                    "orientation": "uptown",
+                },
+            ],
+        )
+    assert (
+        str(e.value)
+        == "Transcription at index 1 in transcriptions: orientation shouldn't be null and should be of type TextOrientation"
+    )
+
 
 def test_create_element_transcriptions_api_error(responses, mock_elements_worker):
     elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
@@ -1173,6 +1412,147 @@ def test_create_element_transcriptions_with_cache(
     ]
 
 
+def test_create_transcriptions_orientation_with_cache(
+    responses, mock_elements_worker_with_cache
+):
+    elt = CachedElement(id="12341234-1234-1234-1234-123412341234", type="thing")
+
+    responses.add(
+        responses.POST,
+        f"http://testserver/api/v1/element/{elt.id}/transcriptions/bulk/",
+        status=200,
+        json=[
+            {
+                "id": "56785678-5678-5678-5678-567856785678",
+                "element_id": "11111111-1111-1111-1111-111111111111",
+                "created": True,
+            },
+            {
+                "id": "67896789-6789-6789-6789-678967896789",
+                "element_id": "22222222-2222-2222-2222-222222222222",
+                "created": False,
+            },
+            {
+                "id": "78907890-7890-7890-7890-789078907890",
+                "element_id": "11111111-1111-1111-1111-111111111111",
+                "created": True,
+            },
+        ],
+    )
+
+    oriented_transcriptions = [
+        {
+            "polygon": [[100, 150], [700, 150], [700, 200], [100, 200]],
+            "score": 0.5,
+            "text": "Animula vagula blandula",
+        },
+        {
+            "polygon": [[0, 0], [2000, 0], [2000, 3000], [0, 3000]],
+            "score": 0.75,
+            "text": "Hospes comesque corporis",
+            "orientation": TextOrientation.VerticalLeftToRight,
+        },
+        {
+            "polygon": [[1000, 300], [1200, 300], [1200, 500], [1000, 500]],
+            "score": 0.9,
+            "text": "Quae nunc abibis in loca",
+            "orientation": TextOrientation.HorizontalRightToLeft,
+        },
+    ]
+
+    annotations = mock_elements_worker_with_cache.create_element_transcriptions(
+        element=elt,
+        sub_element_type="page",
+        transcriptions=oriented_transcriptions,
+    )
+
+    sent_oriented_transcriptions = [
+        {
+            "text": tr["text"],
+            "score": tr["score"],
+            "orientation": tr["orientation"].value,
+            "polygon": tr["polygon"],
+        }
+        for tr in oriented_transcriptions
+    ]
+
+    assert json.loads(responses.calls[-1].request.body) == {
+        "element_type": "page",
+        "worker_version": "12341234-1234-1234-1234-123412341234",
+        "transcriptions": sent_oriented_transcriptions,
+        "return_elements": True,
+    }
+    assert annotations == [
+        {
+            "id": "56785678-5678-5678-5678-567856785678",
+            "element_id": "11111111-1111-1111-1111-111111111111",
+            "created": True,
+        },
+        {
+            "id": "67896789-6789-6789-6789-678967896789",
+            "element_id": "22222222-2222-2222-2222-222222222222",
+            "created": False,
+        },
+        {
+            "id": "78907890-7890-7890-7890-789078907890",
+            "element_id": "11111111-1111-1111-1111-111111111111",
+            "created": True,
+        },
+    ]
+
+    # Check that the text orientation was properly stored in SQLite cache
+    assert list(CachedTranscription.select()) == [
+        CachedTranscription(
+            id=UUID("56785678-5678-5678-5678-567856785678"),
+        ),
+        CachedTranscription(
+            id=UUID("67896789-6789-6789-6789-678967896789"),
+        ),
+        CachedTranscription(
+            id=UUID("78907890-7890-7890-7890-789078907890"),
+        ),
+    ]
+    assert [
+        {
+            k: getattr(transcription, k)
+            for k in [
+                "id",
+                "element_id",
+                "text",
+                "confidence",
+                "orientation",
+                "worker_version_id",
+            ]
+        }
+        for transcription in CachedTranscription.select()
+    ] == [
+        {
+            "id": UUID("56785678-5678-5678-5678-567856785678"),
+            "element_id": UUID("11111111-1111-1111-1111-111111111111"),
+            "text": "Animula vagula blandula",
+            "confidence": 0.5,
+            "orientation": TextOrientation.HorizontalLeftToRight.value,
+            "worker_version_id": UUID("12341234-1234-1234-1234-123412341234"),
+        },
+        {
+            "id": UUID("67896789-6789-6789-6789-678967896789"),
+            "element_id": UUID("22222222-2222-2222-2222-222222222222"),
+            "text": "Hospes comesque corporis",
+            "confidence": 0.75,
+            "orientation": TextOrientation.VerticalLeftToRight.value,
+            "worker_version_id": UUID("12341234-1234-1234-1234-123412341234"),
+        },
+        {
+            "id": UUID("78907890-7890-7890-7890-789078907890"),
+            "element_id": UUID("11111111-1111-1111-1111-111111111111"),
+            "text": "Quae nunc abibis in loca",
+            "confidence": 0.9,
+            "orientation": TextOrientation.HorizontalRightToLeft.value,
+            "worker_version_id": UUID("12341234-1234-1234-1234-123412341234"),
+        },
+    ]
+
+
 def test_list_transcriptions_wrong_element(mock_elements_worker):
     with pytest.raises(AssertionError) as e:
         mock_elements_worker.list_transcriptions(element=None)
-- 
GitLab