From f0a0a7ac4a7e192e152532b84f0fd60e00ef6bbf Mon Sep 17 00:00:00 2001
From: Eva Bardou <ebardou@teklia.com>
Date: Mon, 26 Apr 2021 15:20:47 +0000
Subject: [PATCH] Implement create_transcription_entity + Support local cache

---
 arkindex_worker/cache.py                    |  17 ++
 arkindex_worker/worker/entity.py            |  52 ++++-
 tests/test_cache.py                         |   1 +
 tests/test_elements_worker/test_entities.py | 236 +++++++++++++++++++-
 4 files changed, 304 insertions(+), 2 deletions(-)

diff --git a/arkindex_worker/cache.py b/arkindex_worker/cache.py
index 4505f394..92be4913 100644
--- a/arkindex_worker/cache.py
+++ b/arkindex_worker/cache.py
@@ -6,6 +6,8 @@ import sqlite3
 from peewee import (
     BooleanField,
     CharField,
+    Check,
+    CompositeKey,
     Field,
     FloatField,
     ForeignKeyField,
@@ -138,6 +140,20 @@ class CachedEntity(Model):
         table_name = "entities"
 
 
+class CachedTranscriptionEntity(Model):  # links a transcription to an entity
+    transcription = ForeignKeyField(
+        CachedTranscription, backref="transcription_entities"
+    )
+    entity = ForeignKeyField(CachedEntity, backref="transcription_entities")
+    offset = IntegerField(constraints=[Check("offset >= 0")])  # 0 is allowed
+    length = IntegerField(constraints=[Check("length > 0")])  # must be >= 1
+
+    class Meta:
+        primary_key = CompositeKey("transcription", "entity")  # one row per pair
+        database = db
+        table_name = "transcription_entities"
+
+
 # Add all the managed models in that list
 # It's used here, but also in unit tests
 MODELS = [
@@ -146,6 +162,7 @@ MODELS = [
     CachedTranscription,
     CachedClassification,
     CachedEntity,
+    CachedTranscriptionEntity,
 ]
 
 
diff --git a/arkindex_worker/worker/entity.py b/arkindex_worker/worker/entity.py
index 1689ca03..3f3876e1 100644
--- a/arkindex_worker/worker/entity.py
+++ b/arkindex_worker/worker/entity.py
@@ -5,7 +5,7 @@ from enum import Enum
 from peewee import IntegrityError
 
 from arkindex_worker import logger
-from arkindex_worker.cache import CachedElement, CachedEntity
+from arkindex_worker.cache import CachedElement, CachedEntity, CachedTranscriptionEntity
 from arkindex_worker.models import Element
 
 
@@ -81,3 +81,53 @@ class EntityMixin(object):
                 logger.warning(f"Couldn't save created entity in local cache: {e}")
 
         return entity["id"]
+
+    def create_transcription_entity(self, transcription, entity, offset, length):
+        """
+        Create a link between an existing entity and an existing transcription through API
+
+        :param transcription: ID of the transcription, as a non-empty str.
+        :param entity: ID of the entity to link, as a non-empty str.
+        :param offset: Integer >= 0 sent as the link's offset; 0 is accepted.
+        :param length: Integer > 0 sent as the link's length.
+        :raises AssertionError: When an argument fails the checks above.
+        """
+        assert transcription and isinstance(
+            transcription, str
+        ), "transcription shouldn't be null and should be of type str"
+        assert entity and isinstance(
+            entity, str
+        ), "entity shouldn't be null and should be of type str"
+        assert (
+            offset is not None and isinstance(offset, int) and offset >= 0
+        ), "offset shouldn't be null and should be a positive integer"
+        assert (
+            length and isinstance(length, int) and length > 0
+        ), "length shouldn't be null and should be a strictly positive integer"
+        if self.is_read_only:
+            logger.warning(
+                "Cannot create transcription entity as this worker is in read-only mode"
+            )
+            return
+
+        self.request(
+            "CreateTranscriptionEntity",
+            id=transcription,
+            body={"entity": entity, "length": length, "offset": offset},
+        )
+        # TODO: Report transcription entity creation
+
+        if self.use_cache:
+            # Store the link in local cache; an IntegrityError is only logged
+            try:
+                row = {
+                    "transcription": transcription,
+                    "entity": entity,
+                    "offset": offset,
+                    "length": length,
+                }
+                CachedTranscriptionEntity.insert_many([row]).execute()
+            except IntegrityError as e:
+                logger.warning(
+                    f"Couldn't save created transcription entity in local cache: {e}"
+                )
diff --git a/tests/test_cache.py b/tests/test_cache.py
index bc26c000..ce4ad7f0 100644
--- a/tests/test_cache.py
+++ b/tests/test_cache.py
@@ -58,6 +58,7 @@ def test_create_tables(tmp_path):
 CREATE TABLE "elements" ("id" TEXT NOT NULL PRIMARY KEY, "parent_id" TEXT, "type" VARCHAR(50) NOT NULL, "image_id" TEXT, "polygon" text, "initial" INTEGER NOT NULL, "worker_version_id" TEXT, FOREIGN KEY ("image_id") REFERENCES "images" ("id"))
 CREATE TABLE "entities" ("id" TEXT NOT NULL PRIMARY KEY, "type" VARCHAR(50) NOT NULL, "name" TEXT NOT NULL, "validated" INTEGER NOT NULL, "metas" text, "worker_version_id" TEXT NOT NULL)
 CREATE TABLE "images" ("id" TEXT NOT NULL PRIMARY KEY, "width" INTEGER NOT NULL, "height" INTEGER NOT NULL, "url" TEXT NOT NULL)
+CREATE TABLE "transcription_entities" ("transcription_id" TEXT NOT NULL, "entity_id" TEXT NOT NULL, "offset" INTEGER NOT NULL CHECK (offset >= 0), "length" INTEGER NOT NULL CHECK (length > 0), PRIMARY KEY ("transcription_id", "entity_id"), FOREIGN KEY ("transcription_id") REFERENCES "transcriptions" ("id"), FOREIGN KEY ("entity_id") REFERENCES "entities" ("id"))
 CREATE TABLE "transcriptions" ("id" TEXT NOT NULL PRIMARY KEY, "element_id" TEXT NOT NULL, "text" TEXT NOT NULL, "confidence" REAL NOT NULL, "worker_version_id" TEXT NOT NULL, FOREIGN KEY ("element_id") REFERENCES "elements" ("id"))"""
 
     actual_schema = "\n".join(
diff --git a/tests/test_elements_worker/test_entities.py b/tests/test_elements_worker/test_entities.py
index b03edb0f..d32d3cfb 100644
--- a/tests/test_elements_worker/test_entities.py
+++ b/tests/test_elements_worker/test_entities.py
@@ -5,7 +5,12 @@ from uuid import UUID
 import pytest
 from apistar.exceptions import ErrorResponse
 
-from arkindex_worker.cache import CachedElement, CachedEntity
+from arkindex_worker.cache import (
+    CachedElement,
+    CachedEntity,
+    CachedTranscription,
+    CachedTranscriptionEntity,
+)
 from arkindex_worker.models import Element
 from arkindex_worker.worker import EntityType
 
@@ -258,3 +263,232 @@ def test_create_entity_with_cache(responses, mock_elements_worker_with_cache):
             worker_version_id=UUID("12341234-1234-1234-1234-123412341234"),
         )
     ]
+
+
+def test_create_transcription_entity_wrong_transcription(mock_elements_worker):
+    with pytest.raises(AssertionError) as e:
+        mock_elements_worker.create_transcription_entity(
+            transcription=None,  # missing value
+            entity="11111111-1111-1111-1111-111111111111",
+            offset=5,
+            length=10,
+        )
+    assert str(e.value) == "transcription shouldn't be null and should be of type str"
+
+    with pytest.raises(AssertionError) as e:
+        mock_elements_worker.create_transcription_entity(
+            transcription=1234,  # wrong type: int instead of str
+            entity="11111111-1111-1111-1111-111111111111",
+            offset=5,
+            length=10,
+        )
+    assert str(e.value) == "transcription shouldn't be null and should be of type str"
+
+
+def test_create_transcription_entity_wrong_entity(mock_elements_worker):
+    with pytest.raises(AssertionError) as e:
+        mock_elements_worker.create_transcription_entity(
+            transcription="11111111-1111-1111-1111-111111111111",
+            entity=None,  # missing value
+            offset=5,
+            length=10,
+        )
+    assert str(e.value) == "entity shouldn't be null and should be of type str"
+
+    with pytest.raises(AssertionError) as e:
+        mock_elements_worker.create_transcription_entity(
+            transcription="11111111-1111-1111-1111-111111111111",
+            entity=1234,  # wrong type: int instead of str
+            offset=5,
+            length=10,
+        )
+    assert str(e.value) == "entity shouldn't be null and should be of type str"
+
+
+def test_create_transcription_entity_wrong_offset(mock_elements_worker):
+    with pytest.raises(AssertionError) as e:
+        mock_elements_worker.create_transcription_entity(
+            transcription="11111111-1111-1111-1111-111111111111",
+            entity="11111111-1111-1111-1111-111111111111",
+            offset=None,  # missing value
+            length=10,
+        )
+    assert str(e.value) == "offset shouldn't be null and should be a positive integer"
+
+    with pytest.raises(AssertionError) as e:
+        mock_elements_worker.create_transcription_entity(
+            transcription="11111111-1111-1111-1111-111111111111",
+            entity="11111111-1111-1111-1111-111111111111",
+            offset="not an int",  # wrong type: str instead of int
+            length=10,
+        )
+    assert str(e.value) == "offset shouldn't be null and should be a positive integer"
+
+    with pytest.raises(AssertionError) as e:
+        mock_elements_worker.create_transcription_entity(
+            transcription="11111111-1111-1111-1111-111111111111",
+            entity="11111111-1111-1111-1111-111111111111",
+            offset=-1,  # negative values are rejected
+            length=10,
+        )
+    assert str(e.value) == "offset shouldn't be null and should be a positive integer"
+
+
+def test_create_transcription_entity_wrong_length(mock_elements_worker):
+    with pytest.raises(AssertionError) as e:
+        mock_elements_worker.create_transcription_entity(
+            transcription="11111111-1111-1111-1111-111111111111",
+            entity="11111111-1111-1111-1111-111111111111",
+            offset=5,
+            length=None,  # missing value
+        )
+    assert (
+        str(e.value)
+        == "length shouldn't be null and should be a strictly positive integer"
+    )
+
+    with pytest.raises(AssertionError) as e:
+        mock_elements_worker.create_transcription_entity(
+            transcription="11111111-1111-1111-1111-111111111111",
+            entity="11111111-1111-1111-1111-111111111111",
+            offset=5,
+            length="not an int",  # wrong type: str instead of int
+        )
+    assert (
+        str(e.value)
+        == "length shouldn't be null and should be a strictly positive integer"
+    )
+
+    with pytest.raises(AssertionError) as e:
+        mock_elements_worker.create_transcription_entity(
+            transcription="11111111-1111-1111-1111-111111111111",
+            entity="11111111-1111-1111-1111-111111111111",
+            offset=5,
+            length=0,  # zero is rejected: length must be > 0
+        )
+    assert (
+        str(e.value)
+        == "length shouldn't be null and should be a strictly positive integer"
+    )
+
+
+def test_create_transcription_entity_api_error(responses, mock_elements_worker):
+    responses.add(
+        responses.POST,
+        "http://testserver/api/v1/transcription/11111111-1111-1111-1111-111111111111/entity/",
+        status=500,  # every attempt gets a server error
+    )
+
+    with pytest.raises(ErrorResponse):
+        mock_elements_worker.create_transcription_entity(
+            transcription="11111111-1111-1111-1111-111111111111",
+            entity="11111111-1111-1111-1111-111111111111",
+            offset=5,
+            length=10,
+        )
+
+    assert len(responses.calls) == 7  # 2 other calls + 5 attempts
+    assert [call.request.url for call in responses.calls] == [
+        "http://testserver/api/v1/user/",
+        "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
+        # The API call is attempted 5 times in total
+        "http://testserver/api/v1/transcription/11111111-1111-1111-1111-111111111111/entity/",
+        "http://testserver/api/v1/transcription/11111111-1111-1111-1111-111111111111/entity/",
+        "http://testserver/api/v1/transcription/11111111-1111-1111-1111-111111111111/entity/",
+        "http://testserver/api/v1/transcription/11111111-1111-1111-1111-111111111111/entity/",
+        "http://testserver/api/v1/transcription/11111111-1111-1111-1111-111111111111/entity/",
+    ]
+
+
+def test_create_transcription_entity(responses, mock_elements_worker):
+    responses.add(
+        responses.POST,
+        "http://testserver/api/v1/transcription/11111111-1111-1111-1111-111111111111/entity/",
+        status=200,
+        json={
+            "entity": "11111111-1111-1111-1111-111111111111",
+            "offset": 5,
+            "length": 10,
+        },
+    )
+
+    mock_elements_worker.create_transcription_entity(
+        transcription="11111111-1111-1111-1111-111111111111",
+        entity="11111111-1111-1111-1111-111111111111",
+        offset=5,
+        length=10,
+    )
+
+    assert len(responses.calls) == 3  # 2 other calls + 1 creation call
+    assert [call.request.url for call in responses.calls] == [
+        "http://testserver/api/v1/user/",
+        "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
+        "http://testserver/api/v1/transcription/11111111-1111-1111-1111-111111111111/entity/",
+    ]
+    assert json.loads(responses.calls[2].request.body) == {  # payload sent
+        "entity": "11111111-1111-1111-1111-111111111111",
+        "offset": 5,
+        "length": 10,
+    }
+
+
+def test_create_transcription_entity_with_cache(
+    responses, mock_elements_worker_with_cache
+):
+    CachedElement.create(  # element the transcription below points to
+        id=UUID("12341234-1234-1234-1234-123412341234"),
+        type="page",
+    )
+    CachedTranscription.create(  # row referenced by the link's transcription
+        id=UUID("11111111-1111-1111-1111-111111111111"),
+        element=UUID("12341234-1234-1234-1234-123412341234"),
+        text="Hello, it's me.",
+        confidence=0.42,
+        worker_version_id=UUID("12341234-1234-1234-1234-123412341234"),
+    )
+    CachedEntity.create(  # row referenced by the link's entity
+        id=UUID("11111111-1111-1111-1111-111111111111"),
+        type="person",
+        name="Bob Bob",
+        worker_version_id=UUID("12341234-1234-1234-1234-123412341234"),
+    )
+
+    responses.add(
+        responses.POST,
+        "http://testserver/api/v1/transcription/11111111-1111-1111-1111-111111111111/entity/",
+        status=200,
+        json={
+            "entity": "11111111-1111-1111-1111-111111111111",
+            "offset": 5,
+            "length": 10,
+        },
+    )
+
+    mock_elements_worker_with_cache.create_transcription_entity(
+        transcription="11111111-1111-1111-1111-111111111111",
+        entity="11111111-1111-1111-1111-111111111111",
+        offset=5,
+        length=10,
+    )
+
+    assert len(responses.calls) == 3  # 2 other calls + 1 creation call
+    assert [call.request.url for call in responses.calls] == [
+        "http://testserver/api/v1/user/",
+        "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
+        "http://testserver/api/v1/transcription/11111111-1111-1111-1111-111111111111/entity/",
+    ]
+    assert json.loads(responses.calls[2].request.body) == {  # payload sent
+        "entity": "11111111-1111-1111-1111-111111111111",
+        "offset": 5,
+        "length": 10,
+    }
+
+    # Check that the created transcription entity was stored in the SQLite cache
+    assert list(CachedTranscriptionEntity.select()) == [
+        CachedTranscriptionEntity(
+            transcription=UUID("11111111-1111-1111-1111-111111111111"),
+            entity=UUID("11111111-1111-1111-1111-111111111111"),
+            offset=5,
+            length=10,
+        )
+    ]
-- 
GitLab