From 34d13098096a627a111c1c7462d2da9d0469bdc8 Mon Sep 17 00:00:00 2001 From: Eva Bardou <ebardou@teklia.com> Date: Tue, 23 Mar 2021 16:32:13 +0100 Subject: [PATCH] Add logic to store Transcriptions in cache --- arkindex_worker/cache.py | 13 ++++++++ arkindex_worker/worker.py | 63 +++++++++++++++++++++++++++++++++++++-- 2 files changed, 73 insertions(+), 3 deletions(-) diff --git a/arkindex_worker/cache.py b/arkindex_worker/cache.py index 2a36c219..22930be1 100644 --- a/arkindex_worker/cache.py +++ b/arkindex_worker/cache.py @@ -13,6 +13,14 @@ SQL_ELEMENTS_TABLE_CREATION = """CREATE TABLE IF NOT EXISTS elements ( initial BOOLEAN DEFAULT 0 NOT NULL, worker_version_id VARCHAR(32) )""" +SQL_TRANSCRIPTIONS_TABLE_CREATION = """CREATE TABLE IF NOT EXISTS transcriptions ( + id VARCHAR(32) PRIMARY KEY, + element_id VARCHAR(32) NOT NULL, + text TEXT NOT NULL, + confidence REAL NOT NULL, + worker_version_id VARCHAR(32) NOT NULL, + FOREIGN KEY(element_id) REFERENCES elements(id) +)""" CachedElement = namedtuple( @@ -20,6 +28,10 @@ CachedElement = namedtuple( ["id", "name", "type", "polygon", "worker_version_id", "parent_id", "initial"], defaults=[None, 0], ) +CachedTranscription = namedtuple( + "CachedTranscription", + ["id", "element_id", "text", "confidence", "worker_version_id"], +) class LocalDB(object): @@ -31,6 +43,7 @@ class LocalDB(object): def create_tables(self): self.cursor.execute(SQL_ELEMENTS_TABLE_CREATION) + self.cursor.execute(SQL_TRANSCRIPTIONS_TABLE_CREATION) def insert(self, table, lines): if not lines: diff --git a/arkindex_worker/worker.py b/arkindex_worker/worker.py index 839ee99f..d25badd5 100644 --- a/arkindex_worker/worker.py +++ b/arkindex_worker/worker.py @@ -17,7 +17,7 @@ from apistar.exceptions import ErrorResponse from arkindex import ArkindexClient, options_from_env from arkindex_worker import logger -from arkindex_worker.cache import CachedElement, LocalDB +from arkindex_worker.cache import CachedElement, CachedTranscription, LocalDB from arkindex_worker.models import Element from arkindex_worker.reporting import Reporter from arkindex_worker.utils import convert_str_uuid_to_hex @@ -524,7 +524,7 @@ class ElementsWorker(BaseWorker): ) return - self.api_client.request( + created = self.api_client.request( "CreateTranscription", id=element.id, body={ @@ -533,8 +533,24 @@ class ElementsWorker(BaseWorker): "score": score, }, ) + self.report.add_transcription(element.id) + # Store transcription in local cache + try: + to_insert = [ + CachedTranscription( + id=convert_str_uuid_to_hex(created["id"]), + element_id=convert_str_uuid_to_hex(element.id), + text=created["text"], + confidence=created["confidence"], + worker_version_id=convert_str_uuid_to_hex(self.worker_version_id), + ) + ] + self.cache.insert("transcriptions", to_insert) + except sqlite3.IntegrityError as e: + logger.warning(f"Couldn't save created transcription in local cache: {e}") + def create_classification( self, element, ml_class, confidence, high_confidence=False ): @@ -703,14 +719,55 @@ class ElementsWorker(BaseWorker): "return_elements": True, }, ) - for annotation in annotations: + + created_ids = [] + elements_to_insert = [] + transcriptions_to_insert = [] + parent_id_hex = convert_str_uuid_to_hex(element.id) + worker_version_id_hex = convert_str_uuid_to_hex(self.worker_version_id) + for index, annotation in enumerate(annotations): + transcription = transcriptions[index] + element_id_hex = convert_str_uuid_to_hex(annotation["id"]) if annotation["created"]: logger.debug( f"A sub_element of {element.id} with type {sub_element_type} was created during transcriptions bulk creation" ) self.report.add_element(element.id, sub_element_type) + + if annotation["id"] not in created_ids: + # TODO: Retrieve real element_name through API + elements_to_insert.append( + CachedElement( + id=element_id_hex, + parent_id=parent_id_hex, + name="test", + type=sub_element_type, + polygon=json.dumps(transcription["polygon"]), + worker_version_id=worker_version_id_hex, + ) + ) + created_ids.append(annotation["id"]) + self.report.add_transcription(annotation["id"]) + transcriptions_to_insert.append( + CachedTranscription( + # TODO: Retrieve real transcription_id through API + id=convert_str_uuid_to_hex(uuid.uuid4()), + element_id=element_id_hex, + text=transcription["text"], + confidence=transcription["score"], + worker_version_id=worker_version_id_hex, + ) + ) + + # Store transcriptions and their associated element (if created) in local cache + try: + self.cache.insert("elements", elements_to_insert) + self.cache.insert("transcriptions", transcriptions_to_insert) + except sqlite3.IntegrityError as e: + logger.warning(f"Couldn't save created transcriptions in local cache: {e}") + return annotations def create_metadata(self, element, type, name, value, entity=None): -- GitLab