Skip to content
Snippets Groups Projects

Draft: Refactor and implement API version of the worker

Open Yoann Schneider requested to merge new-api-worker into main
1 file
+ 0
1
Compare changes
  • Side-by-side
  • Inline
+ 110
0
# -*- coding: utf-8 -*-
from collections.abc import Iterator
from arkindex_worker.cache import (
CachedClassification,
CachedElement,
CachedEntity,
CachedTranscription,
CachedTranscriptionEntity,
)
from arkindex_worker.models import Element, Set, Transcription
from arkindex_worker.worker.classification import ClassificationMixin
from arkindex_worker.worker.element import ElementMixin
from arkindex_worker.worker.entity import EntityMixin
from arkindex_worker.worker.metadata import MetaDataMixin
from arkindex_worker.worker.transcription import TranscriptionMixin
from worker_generic_training_dataset import DEFAULT_TRANSCRIPTION_ORIENTATION, Extractor
from worker_generic_training_dataset.utils import get_id_or_null
class DatasetExtractorFromAPI(
    Extractor,
    ElementMixin,
    ClassificationMixin,
    EntityMixin,
    TranscriptionMixin,
    MetaDataMixin,
):
    """Dataset extractor that fetches its data through the Arkindex API
    (via the worker mixins) and converts the payloads into cache rows.
    """

    def list_set_elements(self, dataset_set: Set) -> Iterator[Element]:
        """Yield fully-serialized elements belonging to a dataset set.

        Classifications are not serialized in ListDatasetElements, so each
        element is re-fetched individually through RetrieveElement.
        """
        for partial_element in super().list_set_elements(dataset_set):
            full_payload = self.request("RetrieveElement", id=partial_element.id)
            yield Element(**full_payload)

    def list_element_children(self, *args, **kwargs) -> Iterator[Element]:
        """List an element's children, forcing classifications to be serialized."""
        children = super().list_element_children(*args, **kwargs, with_classes=True)
        return map(Element, children)

    def get_classifications(self, element: CachedElement, classifications: list[dict]):
        """Build CachedClassification rows from API classification payloads."""
        cached_rows = []
        for payload in classifications:
            cached_rows.append(
                CachedClassification(
                    id=payload["id"],
                    element=element,
                    class_name=payload["ml_class"]["name"],
                    confidence=payload["confidence"],
                    state=payload["state"],
                    worker_run_id=get_id_or_null(payload["worker_run"]),
                )
            )
        return cached_rows

    def get_transcriptions(self, element: CachedElement):
        """Build CachedTranscription rows for every transcription of an element."""
        cached_rows = []
        for payload in self.list_transcriptions(element):
            cached_rows.append(
                CachedTranscription(
                    id=payload["id"],
                    element=element,
                    text=payload["text"],
                    confidence=payload["confidence"],
                    # Orientation is not part of the listed payload: use the default
                    orientation=DEFAULT_TRANSCRIPTION_ORIENTATION,
                    worker_version_id=payload["worker_version_id"],
                    worker_run_id=get_id_or_null(payload["worker_run"]),
                )
            )
        return cached_rows

    def get_transcription_entities(
        self, transcription: CachedTranscription
    ) -> tuple[list[CachedEntity], list[CachedTranscriptionEntity]]:
        """Build CachedEntity and CachedTranscriptionEntity rows for a transcription.

        Returns the two lists as a (entities, transcription_entities) tuple so
        the caller can persist both tables.
        """
        entities: list[CachedEntity] = []
        transcription_entities: list[CachedTranscriptionEntity] = []
        for payload in self.list_transcription_entities(
            Transcription(id=transcription.id)
        ):
            entity_payload = payload["entity"]
            cached_entity = CachedEntity(
                id=entity_payload["id"],
                type=entity_payload["type"]["name"],
                name=entity_payload["name"],
                validated=entity_payload["validated"],
                metas=entity_payload["metas"],
                worker_run_id=get_id_or_null(entity_payload["worker_run"]),
            )
            entities.append(cached_entity)
            transcription_entities.append(
                CachedTranscriptionEntity(
                    transcription=transcription,
                    entity=cached_entity,
                    offset=payload["offset"],
                    length=payload["length"],
                    confidence=payload["confidence"],
                    worker_run_id=get_id_or_null(payload["worker_run"]),
                )
            )
        return entities, transcription_entities
def main():
    """Entry point: instantiate and run the API-based dataset extractor."""
    worker = DatasetExtractorFromAPI(
        description="Fill base-worker cache with information about dataset and extract images",
    )
    worker.run()


if __name__ == "__main__":
    main()
Loading