Skip to content
Snippets Groups Projects
Commit f8e167d4 authored by Eva Bardou
Browse files

Fix existing tests

parent f6a16bfa
No related branches found
No related tags found
No related merge requests found
...@@ -512,20 +512,25 @@ class ElementsWorker(BaseWorker): ...@@ -512,20 +512,25 @@ class ElementsWorker(BaseWorker):
self.report.add_transcription(element.id) self.report.add_transcription(element.id)
# Store transcription in local cache if self.cache:
try: # Store transcription in local cache
to_insert = [ try:
CachedTranscription( to_insert = [
id=convert_str_uuid_to_hex(created["id"]), CachedTranscription(
element_id=convert_str_uuid_to_hex(element.id), id=convert_str_uuid_to_hex(created["id"]),
text=created["text"], element_id=convert_str_uuid_to_hex(element.id),
confidence=created["confidence"], text=created["text"],
worker_version_id=convert_str_uuid_to_hex(self.worker_version_id), confidence=created["confidence"],
worker_version_id=convert_str_uuid_to_hex(
self.worker_version_id
),
)
]
self.cache.insert("transcriptions", to_insert)
except sqlite3.IntegrityError as e:
logger.warning(
f"Couldn't save created transcription in local cache: {e}"
) )
]
self.cache.insert("transcriptions", to_insert)
except sqlite3.IntegrityError as e:
logger.warning(f"Couldn't save created transcription in local cache: {e}")
def create_classification( def create_classification(
self, element, ml_class, confidence, high_confidence=False self, element, ml_class, confidence, high_confidence=False
...@@ -676,27 +681,30 @@ class ElementsWorker(BaseWorker): ...@@ -676,27 +681,30 @@ class ElementsWorker(BaseWorker):
}, },
) )
created_ids = [] for annotation in annotations:
elements_to_insert = []
transcriptions_to_insert = []
parent_id_hex = convert_str_uuid_to_hex(element.id)
worker_version_id_hex = convert_str_uuid_to_hex(self.worker_version_id)
for index, annotation in enumerate(annotations):
transcription = transcriptions[index]
element_id_hex = convert_str_uuid_to_hex(annotation["id"])
if annotation["created"]: if annotation["created"]:
logger.debug( logger.debug(
f"A sub_element of {element.id} with type {sub_element_type} was created during transcriptions bulk creation" f"A sub_element of {element.id} with type {sub_element_type} was created during transcriptions bulk creation"
) )
self.report.add_element(element.id, sub_element_type) self.report.add_element(element.id, sub_element_type)
self.report.add_transcription(annotation["id"])
if annotation["id"] not in created_ids: if self.cache:
# Store transcriptions and their associated element (if created) in local cache
created_ids = []
elements_to_insert = []
transcriptions_to_insert = []
parent_id_hex = convert_str_uuid_to_hex(element.id)
worker_version_id_hex = convert_str_uuid_to_hex(self.worker_version_id)
for index, annotation in enumerate(annotations):
transcription = transcriptions[index]
element_id_hex = convert_str_uuid_to_hex(annotation["id"])
if annotation["created"] and annotation["id"] not in created_ids:
# TODO: Retrieve real element_name through API # TODO: Retrieve real element_name through API
elements_to_insert.append( elements_to_insert.append(
CachedElement( CachedElement(
id=element_id_hex, id=element_id_hex,
parent_id=parent_id_hex, parent_id=parent_id_hex,
name="test",
type=sub_element_type, type=sub_element_type,
polygon=json.dumps(transcription["polygon"]), polygon=json.dumps(transcription["polygon"]),
worker_version_id=worker_version_id_hex, worker_version_id=worker_version_id_hex,
...@@ -704,25 +712,24 @@ class ElementsWorker(BaseWorker): ...@@ -704,25 +712,24 @@ class ElementsWorker(BaseWorker):
) )
created_ids.append(annotation["id"]) created_ids.append(annotation["id"])
self.report.add_transcription(annotation["id"]) transcriptions_to_insert.append(
CachedTranscription(
transcriptions_to_insert.append( # TODO: Retrieve real transcription_id through API
CachedTranscription( id=convert_str_uuid_to_hex(uuid.uuid4()),
# TODO: Retrieve real transcription_id through API element_id=element_id_hex,
id=convert_str_uuid_to_hex(uuid.uuid4()), text=transcription["text"],
element_id=element_id_hex, confidence=transcription["score"],
text=transcription["text"], worker_version_id=worker_version_id_hex,
confidence=transcription["score"], )
worker_version_id=worker_version_id_hex,
) )
)
# Store transcriptions and their associated element (if created) in local cache try:
try: self.cache.insert("elements", elements_to_insert)
self.cache.insert("elements", elements_to_insert) self.cache.insert("transcriptions", transcriptions_to_insert)
self.cache.insert("transcriptions", transcriptions_to_insert) except sqlite3.IntegrityError as e:
except sqlite3.IntegrityError as e: logger.warning(
logger.warning(f"Couldn't save created transcriptions in local cache: {e}") f"Couldn't save created transcriptions in local cache: {e}"
)
return annotations return annotations
......
No preview for this file type
No preview for this file type
...@@ -6,7 +6,7 @@ from pathlib import Path ...@@ -6,7 +6,7 @@ from pathlib import Path
import pytest import pytest
from arkindex_worker.cache import CachedElement, LocalDB from arkindex_worker.cache import CachedElement, CachedTranscription, LocalDB
from arkindex_worker.utils import convert_str_uuid_to_hex from arkindex_worker.utils import convert_str_uuid_to_hex
FIXTURES = Path(__file__).absolute().parent / "data/cache" FIXTURES = Path(__file__).absolute().parent / "data/cache"
...@@ -30,6 +30,26 @@ ELEMENTS_TO_INSERT = [ ...@@ -30,6 +30,26 @@ ELEMENTS_TO_INSERT = [
), ),
), ),
] ]
# Transcription rows used as insertion fixtures by the cache tests.
# For each entry, ``id`` and ``element_id`` are derived from the same UUID
# string; every entry shares one confidence value and one worker version.
TRANSCRIPTIONS_TO_INSERT = [
    CachedTranscription(
        id=convert_str_uuid_to_hex(raw_uuid),
        element_id=convert_str_uuid_to_hex(raw_uuid),
        text=content,
        confidence=0.42,
        worker_version_id=convert_str_uuid_to_hex(
            "56785678-5678-5678-5678-567856785678"
        ),
    )
    for raw_uuid, content in (
        ("11111111-1111-1111-1111-111111111111", "Hello!"),
        ("22222222-2222-2222-2222-222222222222", "How are you?"),
    )
]
def test_init_non_existent_path(): def test_init_non_existent_path():
...@@ -108,6 +128,10 @@ def test_insert_existing_lines(): ...@@ -108,6 +128,10 @@ def test_insert_existing_lines():
cache.insert("elements", ELEMENTS_TO_INSERT) cache.insert("elements", ELEMENTS_TO_INSERT)
assert str(e.value) == "UNIQUE constraint failed: elements.id" assert str(e.value) == "UNIQUE constraint failed: elements.id"
with pytest.raises(sqlite3.IntegrityError) as e:
cache.insert("transcriptions", TRANSCRIPTIONS_TO_INSERT)
assert str(e.value) == "UNIQUE constraint failed: transcriptions.id"
with open(db_path, "rb") as after_file: with open(db_path, "rb") as after_file:
after = after_file.read() after = after_file.read()
...@@ -128,3 +152,16 @@ def test_insert(): ...@@ -128,3 +152,16 @@ def test_insert():
) )
assert [CachedElement(**dict(row)) for row in generated_rows] == ELEMENTS_TO_INSERT assert [CachedElement(**dict(row)) for row in generated_rows] == ELEMENTS_TO_INSERT
cache.insert("transcriptions", TRANSCRIPTIONS_TO_INSERT)
generated_rows = cache.cursor.execute("SELECT * FROM transcriptions").fetchall()
expected_cache = LocalDB(f"{FIXTURES}/lines.sqlite")
assert (
generated_rows
== expected_cache.cursor.execute("SELECT * FROM transcriptions").fetchall()
)
assert [
CachedTranscription(**dict(row)) for row in generated_rows
] == TRANSCRIPTIONS_TO_INSERT
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment