
Draft: Refactor and implement API version of the worker

Open · Yoann Schneider requested to merge new-api-worker into main
Files changed: 4
@@ -31,7 +31,6 @@ from arkindex_worker.image import download_image
from arkindex_worker.models import Dataset
from arkindex_worker.models import Element as ArkindexElement
from arkindex_worker.models import Set
from arkindex_worker.models import Transcription as ArkindexTranscription
from arkindex_worker.utils import create_tar_zst_archive
from arkindex_worker.worker import DatasetWorker
from arkindex_worker.worker.dataset import DatasetState
@@ -68,7 +67,6 @@ class Extractor(DatasetWorker):
"""
Create an SQLite database compatible with base-worker cache and initialize it.
"""
self.use_cache = True
self.cache_path: Path = self.data_folder_path / "db.sqlite"
logger.info(f"Cached database will be saved at `{self.cache_path}`.")
@@ -80,17 +78,7 @@ class Extractor(DatasetWorker):
def insert_classifications(self, element: CachedElement) -> None:
logger.info("Listing classifications")
classifications: list[CachedClassification] = [
CachedClassification(
id=classification.id,
element=element,
class_name=classification.class_name,
confidence=classification.confidence,
state=classification.state,
worker_run_id=get_object_id(classification.worker_run),
)
for classification in self.list_classifications(element.id)
]
classifications: list[CachedClassification] = self.get_classifications(element)
if classifications:
logger.info(f"Inserting {len(classifications)} classification(s)")
with cache_database.atomic():
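The new get_classifications helper replaces the inline list comprehension removed above. A plausible sketch of that helper, assuming it simply relocates the removed construction logic (the actual implementation may instead go through the new API client, which this hunk does not show); it reuses the get_object_id helper and the CachedClassification / CachedElement models already imported in this file.

# Sketch of the assumed helper on the Extractor worker, not the actual implementation from this MR.
def get_classifications(self, element: CachedElement) -> list[CachedClassification]:
    # Build one cache row per classification listed on the element,
    # mirroring the comprehension that this MR removes.
    return [
        CachedClassification(
            id=classification.id,
            element=element,
            class_name=classification.class_name,
            confidence=classification.confidence,
            state=classification.state,
            worker_run_id=get_object_id(classification.worker_run),
        )
        for classification in self.list_classifications(element.id)
    ]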
@@ -103,20 +91,7 @@ class Extractor(DatasetWorker):
self, element: CachedElement
) -> List[CachedTranscription]:
logger.info("Listing transcriptions")
transcriptions: list[CachedTranscription] = [
CachedTranscription(
id=transcription.id,
element=element,
text=transcription.text,
confidence=transcription.confidence,
orientation=DEFAULT_TRANSCRIPTION_ORIENTATION,
worker_version_id=get_object_id(transcription.worker_version),
worker_run_id=get_object_id(transcription.worker_run),
)
for transcription in self.list_transcriptions(
ArkindexElement(id=element.id)
)
]
transcriptions: list[CachedTranscription] = self.get_transcriptions(element)
if transcriptions:
logger.info(f"Inserting {len(transcriptions)} transcription(s)")
with cache_database.atomic():
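The new get_transcriptions helper plays the same role for transcriptions. A sketch under the same assumption, namely that it wraps the removed comprehension rather than API calls not visible in this hunk:

# Sketch of the assumed helper on the Extractor worker.
def get_transcriptions(self, element: CachedElement) -> list[CachedTranscription]:
    # One CachedTranscription row per transcription listed on the element.
    return [
        CachedTranscription(
            id=transcription.id,
            element=element,
            text=transcription.text,
            confidence=transcription.confidence,
            orientation=DEFAULT_TRANSCRIPTION_ORIENTATION,
            worker_version_id=get_object_id(transcription.worker_version),
            worker_run_id=get_object_id(transcription.worker_run),
        )
        for transcription in self.list_transcriptions(ArkindexElement(id=element.id))
    ]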
@@ -131,29 +106,10 @@ class Extractor(DatasetWorker):
entities: List[CachedEntity] = []
transcription_entities: List[CachedTranscriptionEntity] = []
for transcription in transcriptions:
for transcription_entity in self.list_transcription_entities(
ArkindexTranscription(id=transcription.id)
):
entity = CachedEntity(
id=transcription_entity.entity.id,
type=transcription_entity.entity.type.name,
name=transcription_entity.entity.name,
validated=transcription_entity.entity.validated,
metas=transcription_entity.entity.metas,
worker_run_id=get_object_id(transcription_entity.entity.worker_run),
)
entities.append(entity)
transcription_entities.append(
CachedTranscriptionEntity(
id=transcription_entity.id,
transcription=transcription,
entity=entity,
offset=transcription_entity.offset,
length=transcription_entity.length,
confidence=transcription_entity.confidence,
worker_run_id=get_object_id(transcription_entity.worker_run),
)
)
parsed_entities = self.get_transcription_entities(transcription)
entities.extend(parsed_entities[0])
transcription_entities.extend(parsed_entities[1])
if entities:
# First insert entities since they are foreign keys on transcription entities
logger.info(f"Inserting {len(entities)} entities")
@@ -194,7 +150,7 @@ class Extractor(DatasetWorker):
:param element: Element to insert.
:param parent_id: ID of the parent to use when creating the CachedElement. Do not specify for top-level elements.
"""
logger.info(f"Processing element ({element.id})")
logger.info(f"Processing element ({element})")
if isinstance(element, Element):
image = element.image
@@ -271,15 +227,14 @@ class Extractor(DatasetWorker):
f"Filling the cache with information from elements in the split {split_name}"
)
for idx, element in enumerate(elements, start=1):
logger.info(f"Processing `{split_name}` element ({idx})")
logger.info(f"Processing `{split_name}` element (n°{idx})")
# Insert page
self.insert_element(element, split_name=split_name)
# List children
children = self.list_element_children(element)
for child_idx, child in enumerate(children, start=1):
logger.info(f"Processing child ({child_idx})")
logger.info(f"Processing {child} ({child_idx})")
# Insert child
self.insert_element(child, parent_id=element.id)