From 34d13098096a627a111c1c7462d2da9d0469bdc8 Mon Sep 17 00:00:00 2001
From: Eva Bardou <ebardou@teklia.com>
Date: Tue, 23 Mar 2021 16:32:13 +0100
Subject: [PATCH] Add logic to store Transcriptions in cache

---
 arkindex_worker/cache.py  | 13 ++++++++
 arkindex_worker/worker.py | 63 +++++++++++++++++++++++++++++++++++++--
 2 files changed, 73 insertions(+), 3 deletions(-)

diff --git a/arkindex_worker/cache.py b/arkindex_worker/cache.py
index 2a36c219..22930be1 100644
--- a/arkindex_worker/cache.py
+++ b/arkindex_worker/cache.py
@@ -13,6 +13,14 @@ SQL_ELEMENTS_TABLE_CREATION = """CREATE TABLE IF NOT EXISTS elements (
     initial BOOLEAN DEFAULT 0 NOT NULL,
     worker_version_id VARCHAR(32)
 )"""
+SQL_TRANSCRIPTIONS_TABLE_CREATION = """CREATE TABLE IF NOT EXISTS transcriptions (
+    id VARCHAR(32) PRIMARY KEY,
+    element_id VARCHAR(32) NOT NULL,
+    text TEXT NOT NULL,
+    confidence REAL NOT NULL,
+    worker_version_id VARCHAR(32) NOT NULL,
+    FOREIGN KEY(element_id) REFERENCES elements(id)
+)"""
 
 
 CachedElement = namedtuple(
@@ -20,6 +28,10 @@ CachedElement = namedtuple(
     ["id", "name", "type", "polygon", "worker_version_id", "parent_id", "initial"],
     defaults=[None, 0],
 )
+CachedTranscription = namedtuple(  # One row of the "transcriptions" cache table
+    "CachedTranscription",
+    ["id", "element_id", "text", "confidence", "worker_version_id"],
+)
 
 
 class LocalDB(object):
@@ -31,6 +43,7 @@ class LocalDB(object):
 
     def create_tables(self):
         self.cursor.execute(SQL_ELEMENTS_TABLE_CREATION)
+        self.cursor.execute(SQL_TRANSCRIPTIONS_TABLE_CREATION)
 
     def insert(self, table, lines):
         if not lines:
diff --git a/arkindex_worker/worker.py b/arkindex_worker/worker.py
index 839ee99f..d25badd5 100644
--- a/arkindex_worker/worker.py
+++ b/arkindex_worker/worker.py
@@ -17,7 +17,7 @@ from apistar.exceptions import ErrorResponse
 
 from arkindex import ArkindexClient, options_from_env
 from arkindex_worker import logger
-from arkindex_worker.cache import CachedElement, LocalDB
+from arkindex_worker.cache import CachedElement, CachedTranscription, LocalDB
 from arkindex_worker.models import Element
 from arkindex_worker.reporting import Reporter
 from arkindex_worker.utils import convert_str_uuid_to_hex
@@ -524,7 +524,7 @@ class ElementsWorker(BaseWorker):
             )
             return
 
-        self.api_client.request(
+        created = self.api_client.request(
             "CreateTranscription",
             id=element.id,
             body={
@@ -533,8 +533,24 @@ class ElementsWorker(BaseWorker):
                 "score": score,
             },
         )
+
         self.report.add_transcription(element.id)
 
+        # Store transcription in local cache
+        try:
+            to_insert = [
+                CachedTranscription(
+                    id=convert_str_uuid_to_hex(created["id"]),
+                    element_id=convert_str_uuid_to_hex(element.id),
+                    text=created["text"],
+                    confidence=created["confidence"],
+                    worker_version_id=convert_str_uuid_to_hex(self.worker_version_id),
+                )
+            ]
+            self.cache.insert("transcriptions", to_insert)
+        except sqlite3.IntegrityError as e:
+            logger.warning(f"Couldn't save created transcription in local cache: {e}")
+
     def create_classification(
         self, element, ml_class, confidence, high_confidence=False
     ):
@@ -703,14 +719,55 @@ class ElementsWorker(BaseWorker):
                 "return_elements": True,
             },
         )
-        for annotation in annotations:
+
+        created_ids = []
+        elements_to_insert = []
+        transcriptions_to_insert = []
+        parent_id_hex = convert_str_uuid_to_hex(element.id)
+        worker_version_id_hex = convert_str_uuid_to_hex(self.worker_version_id)
+        for index, annotation in enumerate(annotations):
+            transcription = transcriptions[index]
+            element_id_hex = convert_str_uuid_to_hex(annotation["id"])
             if annotation["created"]:
                 logger.debug(
                     f"A sub_element of {element.id} with type {sub_element_type} was created during transcriptions bulk creation"
                 )
                 self.report.add_element(element.id, sub_element_type)
+
+                if annotation["id"] not in created_ids:
+                    # TODO: Retrieve the real element name through the API; "test" is only a placeholder
+                    elements_to_insert.append(
+                        CachedElement(
+                            id=element_id_hex,
+                            parent_id=parent_id_hex,
+                            name="test",
+                            type=sub_element_type,
+                            polygon=json.dumps(transcription["polygon"]),
+                            worker_version_id=worker_version_id_hex,
+                        )
+                    )
+                    created_ids.append(annotation["id"])
+
             self.report.add_transcription(annotation["id"])
 
+            transcriptions_to_insert.append(
+                CachedTranscription(
+                    # TODO: Retrieve real transcription_id through API; a random UUID object (not a str) is used until then
+                    id=convert_str_uuid_to_hex(uuid.uuid4()),
+                    element_id=element_id_hex,
+                    text=transcription["text"],
+                    confidence=transcription["score"],
+                    worker_version_id=worker_version_id_hex,
+                )
+            )
+
+        # Store transcriptions and their associated element (if created) in local cache
+        try:
+            self.cache.insert("elements", elements_to_insert)
+            self.cache.insert("transcriptions", transcriptions_to_insert)
+        except sqlite3.IntegrityError as e:
+            logger.warning(f"Couldn't save created transcriptions in local cache: {e}")
+
         return annotations
 
     def create_metadata(self, element, type, name, value, entity=None):
-- 
GitLab