diff --git a/arkindex_worker/cache.py b/arkindex_worker/cache.py index 4c45d1349e7d429fb6435126cfc70b97b834d4ce..c67d1d9e0e16880f76e237bf2ef4cda107ad763f 100644 --- a/arkindex_worker/cache.py +++ b/arkindex_worker/cache.py @@ -12,6 +12,14 @@ SQL_ELEMENTS_TABLE_CREATION = """CREATE TABLE IF NOT EXISTS elements ( initial BOOLEAN DEFAULT 0 NOT NULL, worker_version_id VARCHAR(32) )""" +SQL_TRANSCRIPTIONS_TABLE_CREATION = """CREATE TABLE IF NOT EXISTS transcriptions ( + id VARCHAR(32) PRIMARY KEY, + element_id VARCHAR(32) NOT NULL, + text TEXT NOT NULL, + confidence REAL NOT NULL, + worker_version_id VARCHAR(32) NOT NULL, + FOREIGN KEY(element_id) REFERENCES elements(id) +)""" CachedElement = namedtuple( @@ -19,6 +27,10 @@ CachedElement = namedtuple( ["id", "type", "polygon", "worker_version_id", "parent_id", "initial"], defaults=[None, 0], ) +CachedTranscription = namedtuple( + "CachedTranscription", + ["id", "element_id", "text", "confidence", "worker_version_id"], +) def convert_table_tuple(table): @@ -37,6 +49,7 @@ class LocalDB(object): def create_tables(self): self.cursor.execute(SQL_ELEMENTS_TABLE_CREATION) + self.cursor.execute(SQL_TRANSCRIPTIONS_TABLE_CREATION) def insert(self, table, lines): if not lines: diff --git a/arkindex_worker/worker.py b/arkindex_worker/worker.py index 19ae2a7fc7dd3cc640e3679697e70362a49e4d82..bddb65e038bf8633ead7119591e12d2962ea5032 100644 --- a/arkindex_worker/worker.py +++ b/arkindex_worker/worker.py @@ -16,7 +16,7 @@ from apistar.exceptions import ErrorResponse from arkindex import ArkindexClient, options_from_env from arkindex_worker import logger -from arkindex_worker.cache import CachedElement, LocalDB +from arkindex_worker.cache import CachedElement, CachedTranscription, LocalDB from arkindex_worker.models import Element from arkindex_worker.reporting import Reporter from arkindex_worker.utils import convert_str_uuid_to_hex @@ -500,7 +500,7 @@ class ElementsWorker(BaseWorker): ) return - self.api_client.request( + created = self.api_client.request( "CreateTranscription", id=element.id, body={ @@ -509,8 +509,29 @@ class ElementsWorker(BaseWorker): "score": score, }, ) + self.report.add_transcription(element.id) + if self.cache: + # Store transcription in local cache + try: + to_insert = [ + CachedTranscription( + id=convert_str_uuid_to_hex(created["id"]), + element_id=convert_str_uuid_to_hex(element.id), + text=created["text"], + confidence=created["confidence"], + worker_version_id=convert_str_uuid_to_hex( + self.worker_version_id + ), + ) + ] + self.cache.insert("transcriptions", to_insert) + except sqlite3.IntegrityError as e: + logger.warning( + f"Couldn't save created transcription in local cache: {e}" + ) + def create_classification( self, element, ml_class, confidence, high_confidence=False ): @@ -659,6 +680,7 @@ class ElementsWorker(BaseWorker): "return_elements": True, }, ) + for annotation in annotations: if annotation["created"]: logger.debug( @@ -667,6 +689,47 @@ class ElementsWorker(BaseWorker): self.report.add_element(element.id, sub_element_type) self.report.add_transcription(annotation["id"]) + if self.cache: + # Store transcriptions and their associated element (if created) in local cache + created_ids = [] + elements_to_insert = [] + transcriptions_to_insert = [] + parent_id_hex = convert_str_uuid_to_hex(element.id) + worker_version_id_hex = convert_str_uuid_to_hex(self.worker_version_id) + for index, annotation in enumerate(annotations): + transcription = transcriptions[index] + element_id_hex = convert_str_uuid_to_hex(annotation["id"]) + if annotation["created"] and annotation["id"] not in created_ids: + elements_to_insert.append( + CachedElement( + id=element_id_hex, + parent_id=parent_id_hex, + type=sub_element_type, + polygon=json.dumps(transcription["polygon"]), + worker_version_id=worker_version_id_hex, + ) + ) + created_ids.append(annotation["id"]) + + transcriptions_to_insert.append( + CachedTranscription( + # TODO: Retrieve real transcription_id through API + id=convert_str_uuid_to_hex(str(uuid.uuid4())), + element_id=element_id_hex, + text=transcription["text"], + confidence=transcription["score"], + worker_version_id=worker_version_id_hex, + ) + ) + + try: + self.cache.insert("elements", elements_to_insert) + self.cache.insert("transcriptions", transcriptions_to_insert) + except sqlite3.IntegrityError as e: + logger.warning( + f"Couldn't save created transcriptions in local cache: {e}" + ) + return annotations def create_metadata(self, element, type, name, value, entity=None): diff --git a/tests/data/cache/lines.sqlite b/tests/data/cache/lines.sqlite index ea881e4f1bb2143f560ca81ac6435842494c95f5..d7c476d4c629e98913b2c1f4edf4f66cf3f70fe5 100644 Binary files a/tests/data/cache/lines.sqlite and b/tests/data/cache/lines.sqlite differ diff --git a/tests/data/cache/tables.sqlite b/tests/data/cache/tables.sqlite index efc107cd7506757f92df1ee011b670738228b8b2..f8027fdfbd148d0240047ea2facab93bb2ced474 100644 Binary files a/tests/data/cache/tables.sqlite and b/tests/data/cache/tables.sqlite differ diff --git a/tests/test_cache.py b/tests/test_cache.py index a978e3632ab165d32e49b86ba02230f7b3d2e7e3..8bfc698a1ae1d4fd8e8b3323712d32ad775b7e6a 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -6,7 +6,7 @@ from pathlib import Path import pytest -from arkindex_worker.cache import CachedElement, LocalDB +from arkindex_worker.cache import CachedElement, CachedTranscription, LocalDB from arkindex_worker.utils import convert_str_uuid_to_hex FIXTURES = Path(__file__).absolute().parent / "data/cache" @@ -30,6 +30,26 @@ ELEMENTS_TO_INSERT = [ ), ), ] +TRANSCRIPTIONS_TO_INSERT = [ + CachedTranscription( + id=convert_str_uuid_to_hex("11111111-1111-1111-1111-111111111111"), + element_id=convert_str_uuid_to_hex("11111111-1111-1111-1111-111111111111"), + text="Hello!", + confidence=0.42, + worker_version_id=convert_str_uuid_to_hex( + "56785678-5678-5678-5678-567856785678" + ), + ), + CachedTranscription( + id=convert_str_uuid_to_hex("22222222-2222-2222-2222-222222222222"), + element_id=convert_str_uuid_to_hex("22222222-2222-2222-2222-222222222222"), + text="How are you?", + confidence=0.42, + worker_version_id=convert_str_uuid_to_hex( + "56785678-5678-5678-5678-567856785678" + ), + ), +] def test_init_non_existent_path(): @@ -108,6 +128,10 @@ def test_insert_existing_lines(): cache.insert("elements", ELEMENTS_TO_INSERT) assert str(e.value) == "UNIQUE constraint failed: elements.id" + with pytest.raises(sqlite3.IntegrityError) as e: + cache.insert("transcriptions", TRANSCRIPTIONS_TO_INSERT) + assert str(e.value) == "UNIQUE constraint failed: transcriptions.id" + with open(db_path, "rb") as after_file: after = after_file.read() @@ -129,6 +153,19 @@ def test_insert(): assert [CachedElement(**dict(row)) for row in generated_rows] == ELEMENTS_TO_INSERT + cache.insert("transcriptions", TRANSCRIPTIONS_TO_INSERT) + generated_rows = cache.cursor.execute("SELECT * FROM transcriptions").fetchall() + + expected_cache = LocalDB(f"{FIXTURES}/lines.sqlite") + assert ( + generated_rows + == expected_cache.cursor.execute("SELECT * FROM transcriptions").fetchall() + ) + + assert [ + CachedTranscription(**dict(row)) for row in generated_rows + ] == TRANSCRIPTIONS_TO_INSERT + def test_fetch_all(): db_path = f"{FIXTURES}/lines.sqlite" diff --git a/tests/test_elements_worker/test_transcriptions.py b/tests/test_elements_worker/test_transcriptions.py index 3d603fcb7f6e548f6398ee9330a47e75194ff2a7..7d350e3558135638115d1e3a7ce382bdad9be287 100644 --- a/tests/test_elements_worker/test_transcriptions.py +++ b/tests/test_elements_worker/test_transcriptions.py @@ -1,11 +1,16 @@ # -*- coding: utf-8 -*- import json +import os +from pathlib import Path import pytest from apistar.exceptions import ErrorResponse +from arkindex_worker.cache import CachedElement, CachedTranscription from arkindex_worker.models import Element +from arkindex_worker.utils import convert_str_uuid_to_hex +CACHE_DIR = Path(__file__).absolute().parent.parent / "data/cache" TRANSCRIPTIONS_SAMPLE = [ { "polygon": [[100, 150], [700, 150], [700, 200], [100, 200]], @@ -130,15 +135,22 @@ def test_create_transcription_api_error(responses, mock_elements_worker): ] -def test_create_transcription(responses, mock_elements_worker): +def test_create_transcription(responses, mock_elements_worker_with_cache): elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) responses.add( responses.POST, f"http://testserver/api/v1/element/{elt.id}/transcription/", status=200, + json={ + "id": "56785678-5678-5678-5678-567856785678", + "text": "i am a line", + "score": 0.42, + "confidence": 0.42, + "worker_version_id": "12341234-1234-1234-1234-123412341234", + }, ) - mock_elements_worker.create_transcription( + mock_elements_worker_with_cache.create_transcription( element=elt, text="i am a line", score=0.42, @@ -157,6 +169,25 @@ def test_create_transcription(responses, mock_elements_worker): "score": 0.42, } + # Check that created transcription was properly stored in SQLite cache + cache_path = f"{CACHE_DIR}/db.sqlite" + assert os.path.isfile(cache_path) + + rows = mock_elements_worker_with_cache.cache.cursor.execute( + "SELECT * FROM transcriptions" + ).fetchall() + assert [CachedTranscription(**dict(row)) for row in rows] == [ + CachedTranscription( + id=convert_str_uuid_to_hex("56785678-5678-5678-5678-567856785678"), + element_id=convert_str_uuid_to_hex(elt.id), + text="i am a line", + confidence=0.42, + worker_version_id=convert_str_uuid_to_hex( + "12341234-1234-1234-1234-123412341234" + ), + ) + ] + def test_create_element_transcriptions_wrong_element(mock_elements_worker): with pytest.raises(AssertionError) as e: @@ -551,20 +582,30 @@ def test_create_element_transcriptions_api_error(responses, mock_elements_worker ] -def test_create_element_transcriptions(responses, mock_elements_worker): +def test_create_element_transcriptions( + mocker, responses, mock_elements_worker_with_cache +): + mocker.patch( + "uuid.uuid4", + side_effect=[ + "56785678-5678-5678-5678-567856785678", + "67896789-6789-6789-6789-678967896789", + "78907890-7890-7890-7890-789078907890", + ], + ) elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) responses.add( responses.POST, f"http://testserver/api/v1/element/{elt.id}/transcriptions/bulk/", status=200, json=[ - {"id": "word1_1_1", "created": False}, - {"id": "word1_1_2", "created": False}, - {"id": "word1_1_3", "created": False}, + {"id": "11111111-1111-1111-1111-111111111111", "created": True}, + {"id": "22222222-2222-2222-2222-222222222222", "created": False}, + {"id": "11111111-1111-1111-1111-111111111111", "created": True}, ], ) - annotations = mock_elements_worker.create_element_transcriptions( + annotations = mock_elements_worker_with_cache.create_element_transcriptions( element=elt, sub_element_type="page", transcriptions=TRANSCRIPTIONS_SAMPLE, @@ -584,9 +625,60 @@ def test_create_element_transcriptions(responses, mock_elements_worker): "return_elements": True, } assert annotations == [ - {"id": "word1_1_1", "created": False}, - {"id": "word1_1_2", "created": False}, - {"id": "word1_1_3", "created": False}, + {"id": "11111111-1111-1111-1111-111111111111", "created": True}, + {"id": "22222222-2222-2222-2222-222222222222", "created": False}, + {"id": "11111111-1111-1111-1111-111111111111", "created": True}, + ] + + # Check that created transcriptions and elements were properly stored in SQLite cache + cache_path = f"{CACHE_DIR}/db.sqlite" + assert os.path.isfile(cache_path) + + rows = mock_elements_worker_with_cache.cache.cursor.execute( + "SELECT * FROM elements" + ).fetchall() + assert [CachedElement(**dict(row)) for row in rows] == [ + CachedElement( + id=convert_str_uuid_to_hex("11111111-1111-1111-1111-111111111111"), + parent_id=convert_str_uuid_to_hex("12341234-1234-1234-1234-123412341234"), + type="page", + polygon=json.dumps([[100, 150], [700, 150], [700, 200], [100, 200]]), + worker_version_id=convert_str_uuid_to_hex( + "12341234-1234-1234-1234-123412341234" + ), + ) + ] + rows = mock_elements_worker_with_cache.cache.cursor.execute( + "SELECT * FROM transcriptions" + ).fetchall() + assert [CachedTranscription(**dict(row)) for row in rows] == [ + CachedTranscription( + id=convert_str_uuid_to_hex("56785678-5678-5678-5678-567856785678"), + element_id=convert_str_uuid_to_hex("11111111-1111-1111-1111-111111111111"), + text="The", + confidence=0.5, + worker_version_id=convert_str_uuid_to_hex( + "12341234-1234-1234-1234-123412341234" + ), + ), + CachedTranscription( + id=convert_str_uuid_to_hex("67896789-6789-6789-6789-678967896789"), + element_id=convert_str_uuid_to_hex("22222222-2222-2222-2222-222222222222"), + text="first", + confidence=0.75, + worker_version_id=convert_str_uuid_to_hex( + "12341234-1234-1234-1234-123412341234" + ), + ), + CachedTranscription( + id=convert_str_uuid_to_hex("78907890-7890-7890-7890-789078907890"), + element_id=convert_str_uuid_to_hex("11111111-1111-1111-1111-111111111111"), + text="line", + confidence=0.9, + worker_version_id=convert_str_uuid_to_hex( + "12341234-1234-1234-1234-123412341234" + ), + ), ]