Skip to content
Snippets Groups Projects
Commit f6a16bfa authored by Eva Bardou
Browse files

Add logic to store Transcriptions in cache

parent 32459852
No related branches found
No related tags found
No related merge requests found
......@@ -12,6 +12,14 @@ SQL_ELEMENTS_TABLE_CREATION = """CREATE TABLE IF NOT EXISTS elements (
initial BOOLEAN DEFAULT 0 NOT NULL,
worker_version_id VARCHAR(32)
)"""
# DDL for the local-cache table that stores transcriptions; it is executed by
# LocalDB.create_tables. The columns match the fields of the
# CachedTranscription namedtuple, in the same order.
# element_id is declared as a foreign key to elements(id).
# NOTE(review): assuming the cache is backed by SQLite, FOREIGN KEY constraints
# are only enforced when `PRAGMA foreign_keys = ON` is set on the connection —
# confirm that LocalDB enables it.
SQL_TRANSCRIPTIONS_TABLE_CREATION = """CREATE TABLE IF NOT EXISTS transcriptions (
id VARCHAR(32) PRIMARY KEY,
element_id VARCHAR(32) NOT NULL,
text TEXT NOT NULL,
confidence REAL NOT NULL,
worker_version_id VARCHAR(32) NOT NULL,
FOREIGN KEY(element_id) REFERENCES elements(id)
)"""
CachedElement = namedtuple(
......@@ -19,6 +27,10 @@ CachedElement = namedtuple(
["id", "type", "polygon", "worker_version_id", "parent_id", "initial"],
defaults=[None, 0],
)
# One row of the ``transcriptions`` cache table. Every field is required
# (no defaults), and the field order follows the table's column order.
CachedTranscription = namedtuple(
    typename="CachedTranscription",
    field_names=("id", "element_id", "text", "confidence", "worker_version_id"),
)
class LocalDB(object):
......@@ -30,6 +42,7 @@ class LocalDB(object):
def create_tables(self):
    """Run the table-creation statements for the local cache.

    The elements statement is executed first, then the transcriptions
    statement, both through ``self.cursor``.
    """
    creation_statements = (
        SQL_ELEMENTS_TABLE_CREATION,
        SQL_TRANSCRIPTIONS_TABLE_CREATION,
    )
    for statement in creation_statements:
        self.cursor.execute(statement)
def insert(self, table, lines):
if not lines:
......
......@@ -16,7 +16,7 @@ from apistar.exceptions import ErrorResponse
from arkindex import ArkindexClient, options_from_env
from arkindex_worker import logger
from arkindex_worker.cache import CachedElement, LocalDB
from arkindex_worker.cache import CachedElement, CachedTranscription, LocalDB
from arkindex_worker.models import Element
from arkindex_worker.reporting import Reporter
from arkindex_worker.utils import convert_str_uuid_to_hex
......@@ -500,7 +500,7 @@ class ElementsWorker(BaseWorker):
)
return
self.api_client.request(
created = self.api_client.request(
"CreateTranscription",
id=element.id,
body={
......@@ -509,8 +509,24 @@ class ElementsWorker(BaseWorker):
"score": score,
},
)
self.report.add_transcription(element.id)
# Store transcription in local cache
try:
to_insert = [
CachedTranscription(
id=convert_str_uuid_to_hex(created["id"]),
element_id=convert_str_uuid_to_hex(element.id),
text=created["text"],
confidence=created["confidence"],
worker_version_id=convert_str_uuid_to_hex(self.worker_version_id),
)
]
self.cache.insert("transcriptions", to_insert)
except sqlite3.IntegrityError as e:
logger.warning(f"Couldn't save created transcription in local cache: {e}")
def create_classification(
self, element, ml_class, confidence, high_confidence=False
):
......@@ -659,14 +675,55 @@ class ElementsWorker(BaseWorker):
"return_elements": True,
},
)
for annotation in annotations:
created_ids = []
elements_to_insert = []
transcriptions_to_insert = []
parent_id_hex = convert_str_uuid_to_hex(element.id)
worker_version_id_hex = convert_str_uuid_to_hex(self.worker_version_id)
for index, annotation in enumerate(annotations):
transcription = transcriptions[index]
element_id_hex = convert_str_uuid_to_hex(annotation["id"])
if annotation["created"]:
logger.debug(
f"A sub_element of {element.id} with type {sub_element_type} was created during transcriptions bulk creation"
)
self.report.add_element(element.id, sub_element_type)
if annotation["id"] not in created_ids:
# TODO: Retrieve real element_name through API
elements_to_insert.append(
CachedElement(
id=element_id_hex,
parent_id=parent_id_hex,
name="test",
type=sub_element_type,
polygon=json.dumps(transcription["polygon"]),
worker_version_id=worker_version_id_hex,
)
)
created_ids.append(annotation["id"])
self.report.add_transcription(annotation["id"])
transcriptions_to_insert.append(
CachedTranscription(
# TODO: Retrieve real transcription_id through API
id=convert_str_uuid_to_hex(uuid.uuid4()),
element_id=element_id_hex,
text=transcription["text"],
confidence=transcription["score"],
worker_version_id=worker_version_id_hex,
)
)
# Store transcriptions and their associated element (if created) in local cache
try:
self.cache.insert("elements", elements_to_insert)
self.cache.insert("transcriptions", transcriptions_to_insert)
except sqlite3.IntegrityError as e:
logger.warning(f"Couldn't save created transcriptions in local cache: {e}")
return annotations
def create_metadata(self, element, type, name, value, entity=None):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment