From a11872b5d99ad5bb60447e70e518596626441e62 Mon Sep 17 00:00:00 2001
From: Martin Maarand <maarand@teklia.com>
Date: Thu, 18 Nov 2021 16:32:49 +0000
Subject: [PATCH] Resolve "Support confidence in create_transcription_entity
 helper"

---
 arkindex_worker/cache.py                    |  1 +
 arkindex_worker/worker/entity.py            |  9 ++-
 tests/test_cache.py                         |  2 +-
 tests/test_elements_worker/test_entities.py | 74 +++++++++++++++++++++
 4 files changed, 84 insertions(+), 2 deletions(-)

diff --git a/arkindex_worker/cache.py b/arkindex_worker/cache.py
index b184460d..b1be02e2 100644
--- a/arkindex_worker/cache.py
+++ b/arkindex_worker/cache.py
@@ -166,6 +166,7 @@ class CachedTranscriptionEntity(Model):
     offset = IntegerField(constraints=[Check("offset >= 0")])
     length = IntegerField(constraints=[Check("length > 0")])
     worker_version_id = UUIDField()
+    confidence = FloatField(null=True)
 
     class Meta:
         primary_key = CompositeKey("transcription", "entity")
diff --git a/arkindex_worker/worker/entity.py b/arkindex_worker/worker/entity.py
index 1cf56165..449e8ce3 100644
--- a/arkindex_worker/worker/entity.py
+++ b/arkindex_worker/worker/entity.py
@@ -82,7 +82,9 @@ class EntityMixin(object):
 
         return entity["id"]
 
-    def create_transcription_entity(self, transcription, entity, offset, length):
+    def create_transcription_entity(
+        self, transcription, entity, offset, length, confidence=None
+    ):
         """
         Create a link between an existing entity and an existing transcription through API
         """
@@ -98,6 +100,9 @@ class EntityMixin(object):
         assert (
             length is not None and isinstance(length, int) and length > 0
         ), "length shouldn't be null and should be a strictly positive integer"
+        assert (
+            confidence is None or isinstance(confidence, float) and 0 <= confidence <= 1
+        ), "confidence should be null or a float in [0..1] range"
         if self.is_read_only:
             logger.warning(
                 "Cannot create transcription entity as this worker is in read-only mode"
@@ -112,6 +117,7 @@ class EntityMixin(object):
                 "length": length,
                 "offset": offset,
                 "worker_version_id": self.worker_version_id,
+                "confidence": confidence,
             },
         )
         # TODO: Report transcription entity creation
@@ -125,6 +131,7 @@ class EntityMixin(object):
                     offset=offset,
                     length=length,
                     worker_version_id=self.worker_version_id,
+                    confidence=confidence,
                 )
             except IntegrityError as e:
                 logger.warning(
diff --git a/tests/test_cache.py b/tests/test_cache.py
index 03a68144..fcc7e8e7 100644
--- a/tests/test_cache.py
+++ b/tests/test_cache.py
@@ -58,7 +58,7 @@ def test_create_tables(tmp_path):
 CREATE TABLE "elements" ("id" TEXT NOT NULL PRIMARY KEY, "parent_id" TEXT, "type" VARCHAR(50) NOT NULL, "image_id" TEXT, "polygon" text, "rotation_angle" INTEGER NOT NULL, "mirrored" INTEGER NOT NULL, "initial" INTEGER NOT NULL, "worker_version_id" TEXT, FOREIGN KEY ("image_id") REFERENCES "images" ("id"))
 CREATE TABLE "entities" ("id" TEXT NOT NULL PRIMARY KEY, "type" VARCHAR(50) NOT NULL, "name" TEXT NOT NULL, "validated" INTEGER NOT NULL, "metas" text, "worker_version_id" TEXT NOT NULL)
 CREATE TABLE "images" ("id" TEXT NOT NULL PRIMARY KEY, "width" INTEGER NOT NULL, "height" INTEGER NOT NULL, "url" TEXT NOT NULL)
-CREATE TABLE "transcription_entities" ("transcription_id" TEXT NOT NULL, "entity_id" TEXT NOT NULL, "offset" INTEGER NOT NULL CHECK (offset >= 0), "length" INTEGER NOT NULL CHECK (length > 0), "worker_version_id" TEXT NOT NULL, PRIMARY KEY ("transcription_id", "entity_id"), FOREIGN KEY ("transcription_id") REFERENCES "transcriptions" ("id"), FOREIGN KEY ("entity_id") REFERENCES "entities" ("id"))
+CREATE TABLE "transcription_entities" ("transcription_id" TEXT NOT NULL, "entity_id" TEXT NOT NULL, "offset" INTEGER NOT NULL CHECK (offset >= 0), "length" INTEGER NOT NULL CHECK (length > 0), "worker_version_id" TEXT NOT NULL, "confidence" REAL, PRIMARY KEY ("transcription_id", "entity_id"), FOREIGN KEY ("transcription_id") REFERENCES "transcriptions" ("id"), FOREIGN KEY ("entity_id") REFERENCES "entities" ("id"))
 CREATE TABLE "transcriptions" ("id" TEXT NOT NULL PRIMARY KEY, "element_id" TEXT NOT NULL, "text" TEXT NOT NULL, "confidence" REAL NOT NULL, "orientation" VARCHAR(50) NOT NULL, "worker_version_id" TEXT NOT NULL, FOREIGN KEY ("element_id") REFERENCES "elements" ("id"))"""
 
     actual_schema = "\n".join(
diff --git a/tests/test_elements_worker/test_entities.py b/tests/test_elements_worker/test_entities.py
index e4f43b6f..075f8e5b 100644
--- a/tests/test_elements_worker/test_entities.py
+++ b/tests/test_elements_worker/test_entities.py
@@ -451,6 +451,7 @@ def test_create_transcription_entity(responses, mock_elements_worker):
         "offset": 5,
         "length": 10,
         "worker_version_id": "12341234-1234-1234-1234-123412341234",
+        "confidence": None,
     }
 
 
@@ -508,6 +509,7 @@ def test_create_transcription_entity_with_cache(
         "offset": 5,
         "length": 10,
         "worker_version_id": "12341234-1234-1234-1234-123412341234",
+        "confidence": None,
     }
 
     # Check that created transcription entity was properly stored in SQLite cache
@@ -520,3 +522,75 @@ def test_create_transcription_entity_with_cache(
             worker_version_id=UUID("12341234-1234-1234-1234-123412341234"),
         )
     ]
+
+
+def test_create_transcription_entity_with_confidence_with_cache(
+    responses, mock_elements_worker_with_cache
+):
+    CachedElement.create(
+        id=UUID("12341234-1234-1234-1234-123412341234"),
+        type="page",
+    )
+    CachedTranscription.create(
+        id=UUID("11111111-1111-1111-1111-111111111111"),
+        element=UUID("12341234-1234-1234-1234-123412341234"),
+        text="Hello, it's me.",
+        confidence=0.42,
+        orientation=TextOrientation.HorizontalLeftToRight,
+        worker_version_id=UUID("12341234-1234-1234-1234-123412341234"),
+    )
+    CachedEntity.create(
+        id=UUID("11111111-1111-1111-1111-111111111111"),
+        type="person",
+        name="Bob Bob",
+        worker_version_id=UUID("12341234-1234-1234-1234-123412341234"),
+    )
+
+    responses.add(
+        responses.POST,
+        "http://testserver/api/v1/transcription/11111111-1111-1111-1111-111111111111/entity/",
+        status=200,
+        json={
+            "entity": "11111111-1111-1111-1111-111111111111",
+            "offset": 5,
+            "length": 10,
+            "confidence": 0.77,
+        },
+    )
+
+    mock_elements_worker_with_cache.create_transcription_entity(
+        transcription="11111111-1111-1111-1111-111111111111",
+        entity="11111111-1111-1111-1111-111111111111",
+        offset=5,
+        length=10,
+        confidence=0.77,
+    )
+
+    assert len(responses.calls) == len(BASE_API_CALLS) + 1
+    assert [
+        (call.request.method, call.request.url) for call in responses.calls
+    ] == BASE_API_CALLS + [
+        (
+            "POST",
+            "http://testserver/api/v1/transcription/11111111-1111-1111-1111-111111111111/entity/",
+        ),
+    ]
+    assert json.loads(responses.calls[-1].request.body) == {
+        "entity": "11111111-1111-1111-1111-111111111111",
+        "offset": 5,
+        "length": 10,
+        "worker_version_id": "12341234-1234-1234-1234-123412341234",
+        "confidence": 0.77,
+    }
+
+    # Check the SQLite cache row. NOTE(review): peewee == may ignore confidence
+    assert list(CachedTranscriptionEntity.select()) == [
+        CachedTranscriptionEntity(
+            transcription=UUID("11111111-1111-1111-1111-111111111111"),
+            entity=UUID("11111111-1111-1111-1111-111111111111"),
+            offset=5,
+            length=10,
+            worker_version_id=UUID("12341234-1234-1234-1234-123412341234"),
+            confidence=0.77,
+        )
+    ]
-- 
GitLab