Skip to content
Snippets Groups Projects
Commit 7045a22a authored by Eva Bardou's avatar Eva Bardou
Browse files

Add logic to store Transcriptions in cache

parent e2c63f34
No related branches found
No related tags found
No related merge requests found
...@@ -12,6 +12,14 @@ SQL_ELEMENTS_TABLE_CREATION = """CREATE TABLE IF NOT EXISTS elements ( ...@@ -12,6 +12,14 @@ SQL_ELEMENTS_TABLE_CREATION = """CREATE TABLE IF NOT EXISTS elements (
initial BOOLEAN DEFAULT 0 NOT NULL, initial BOOLEAN DEFAULT 0 NOT NULL,
worker_version_id VARCHAR(32) worker_version_id VARCHAR(32)
)""" )"""
# DDL statement creating the local cache table that stores transcriptions.
# The statement is idempotent (CREATE TABLE IF NOT EXISTS). Every column except
# the primary key is declared NOT NULL, and element_id is declared as a foreign
# key referencing elements(id).
# NOTE(review): SQLite only enforces foreign keys when `PRAGMA foreign_keys`
# is enabled on the connection — confirm whether LocalDB turns it on.
SQL_TRANSCRIPTIONS_TABLE_CREATION = """CREATE TABLE IF NOT EXISTS transcriptions (
id VARCHAR(32) PRIMARY KEY,
element_id VARCHAR(32) NOT NULL,
text TEXT NOT NULL,
confidence REAL NOT NULL,
worker_version_id VARCHAR(32) NOT NULL,
FOREIGN KEY(element_id) REFERENCES elements(id)
)"""
CachedElement = namedtuple( CachedElement = namedtuple(
...@@ -19,6 +27,10 @@ CachedElement = namedtuple( ...@@ -19,6 +27,10 @@ CachedElement = namedtuple(
["id", "type", "polygon", "worker_version_id", "parent_id", "initial"], ["id", "type", "polygon", "worker_version_id", "parent_id", "initial"],
defaults=[None, 0], defaults=[None, 0],
) )
# In-memory record for one transcription row; the field names mirror, in the
# same order, the columns declared in SQL_TRANSCRIPTIONS_TABLE_CREATION.
# No defaults are provided, so all five values must be supplied.
CachedTranscription = namedtuple(
    typename="CachedTranscription",
    field_names=("id", "element_id", "text", "confidence", "worker_version_id"),
)
class LocalDB(object): class LocalDB(object):
...@@ -30,6 +42,7 @@ class LocalDB(object): ...@@ -30,6 +42,7 @@ class LocalDB(object):
def create_tables(self):
    """Create the cache tables, elements first and then transcriptions.

    Both statements use CREATE TABLE IF NOT EXISTS, so calling this on an
    already-initialised database leaves existing tables in place.
    """
    # Order matters for readability only here: transcriptions declares a
    # foreign key on elements, so the referenced table is created first.
    for ddl_statement in (
        SQL_ELEMENTS_TABLE_CREATION,
        SQL_TRANSCRIPTIONS_TABLE_CREATION,
    ):
        self.cursor.execute(ddl_statement)
def insert(self, table, lines): def insert(self, table, lines):
if not lines: if not lines:
......
...@@ -17,7 +17,7 @@ from apistar.exceptions import ErrorResponse ...@@ -17,7 +17,7 @@ from apistar.exceptions import ErrorResponse
from arkindex import ArkindexClient, options_from_env from arkindex import ArkindexClient, options_from_env
from arkindex_worker import logger from arkindex_worker import logger
from arkindex_worker.cache import CachedElement, LocalDB from arkindex_worker.cache import CachedElement, CachedTranscription, LocalDB
from arkindex_worker.models import Element from arkindex_worker.models import Element
from arkindex_worker.reporting import Reporter from arkindex_worker.reporting import Reporter
from arkindex_worker.utils import convert_str_uuid_to_hex from arkindex_worker.utils import convert_str_uuid_to_hex
...@@ -524,7 +524,7 @@ class ElementsWorker(BaseWorker): ...@@ -524,7 +524,7 @@ class ElementsWorker(BaseWorker):
) )
return return
self.api_client.request( created = self.api_client.request(
"CreateTranscription", "CreateTranscription",
id=element.id, id=element.id,
body={ body={
...@@ -533,8 +533,24 @@ class ElementsWorker(BaseWorker): ...@@ -533,8 +533,24 @@ class ElementsWorker(BaseWorker):
"score": score, "score": score,
}, },
) )
self.report.add_transcription(element.id) self.report.add_transcription(element.id)
# Store transcription in local cache
try:
to_insert = [
CachedTranscription(
id=convert_str_uuid_to_hex(created["id"]),
element_id=convert_str_uuid_to_hex(element.id),
text=created["text"],
confidence=created["confidence"],
worker_version_id=convert_str_uuid_to_hex(self.worker_version_id),
)
]
self.cache.insert("transcriptions", to_insert)
except sqlite3.IntegrityError as e:
logger.warning(f"Couldn't save created transcription in local cache: {e}")
def create_classification( def create_classification(
self, element, ml_class, confidence, high_confidence=False self, element, ml_class, confidence, high_confidence=False
): ):
...@@ -703,14 +719,55 @@ class ElementsWorker(BaseWorker): ...@@ -703,14 +719,55 @@ class ElementsWorker(BaseWorker):
"return_elements": True, "return_elements": True,
}, },
) )
for annotation in annotations:
created_ids = []
elements_to_insert = []
transcriptions_to_insert = []
parent_id_hex = convert_str_uuid_to_hex(element.id)
worker_version_id_hex = convert_str_uuid_to_hex(self.worker_version_id)
for index, annotation in enumerate(annotations):
transcription = transcriptions[index]
element_id_hex = convert_str_uuid_to_hex(annotation["id"])
if annotation["created"]: if annotation["created"]:
logger.debug( logger.debug(
f"A sub_element of {element.id} with type {sub_element_type} was created during transcriptions bulk creation" f"A sub_element of {element.id} with type {sub_element_type} was created during transcriptions bulk creation"
) )
self.report.add_element(element.id, sub_element_type) self.report.add_element(element.id, sub_element_type)
if annotation["id"] not in created_ids:
# TODO: Retrieve real element_name through API
elements_to_insert.append(
CachedElement(
id=element_id_hex,
parent_id=parent_id_hex,
name="test",
type=sub_element_type,
polygon=json.dumps(transcription["polygon"]),
worker_version_id=worker_version_id_hex,
)
)
created_ids.append(annotation["id"])
self.report.add_transcription(annotation["id"]) self.report.add_transcription(annotation["id"])
transcriptions_to_insert.append(
CachedTranscription(
# TODO: Retrieve real transcription_id through API
id=convert_str_uuid_to_hex(uuid.uuid4()),
element_id=element_id_hex,
text=transcription["text"],
confidence=transcription["score"],
worker_version_id=worker_version_id_hex,
)
)
# Store transcriptions and their associated element (if created) in local cache
try:
self.cache.insert("elements", elements_to_insert)
self.cache.insert("transcriptions", transcriptions_to_insert)
except sqlite3.IntegrityError as e:
logger.warning(f"Couldn't save created transcriptions in local cache: {e}")
return annotations return annotations
def create_metadata(self, element, type, name, value, entity=None): def create_metadata(self, element, type, name, value, entity=None):
......
0% Loading. Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment