Verified Commit 5f87f70f authored by Yoann Schneider

Fix API worker (push ml results)

parent 83ce2bfe
Merge request !25 (Draft): Refactor and implement API version of the worker
Pipeline #170947 passed
@@ -31,7 +31,6 @@ from arkindex_worker.image import download_image
 from arkindex_worker.models import Dataset
 from arkindex_worker.models import Element as ArkindexElement
 from arkindex_worker.models import Set
-from arkindex_worker.models import Transcription as ArkindexTranscription
 from arkindex_worker.utils import create_tar_zst_archive
 from arkindex_worker.worker import DatasetWorker
 from arkindex_worker.worker.dataset import DatasetState
@@ -68,7 +67,6 @@ class Extractor(DatasetWorker):
         """
         Create an SQLite database compatible with base-worker cache and initialize it.
         """
-        self.use_cache = True
         self.cache_path: Path = self.data_folder_path / "db.sqlite"
         logger.info(f"Cached database will be saved at `{self.cache_path}`.")
@@ -80,17 +78,7 @@
     def insert_classifications(self, element: CachedElement) -> None:
         logger.info("Listing classifications")
-        classifications: list[CachedClassification] = [
-            CachedClassification(
-                id=classification.id,
-                element=element,
-                class_name=classification.class_name,
-                confidence=classification.confidence,
-                state=classification.state,
-                worker_run_id=get_object_id(classification.worker_run),
-            )
-            for classification in self.list_classifications(element.id)
-        ]
+        classifications: list[CachedClassification] = self.get_classifications(element)
         if classifications:
             logger.info(f"Inserting {len(classifications)} classification(s)")
             with cache_database.atomic():
@@ -103,20 +91,7 @@
         self, element: CachedElement
     ) -> List[CachedTranscription]:
         logger.info("Listing transcriptions")
-        transcriptions: list[CachedTranscription] = [
-            CachedTranscription(
-                id=transcription.id,
-                element=element,
-                text=transcription.text,
-                confidence=transcription.confidence,
-                orientation=DEFAULT_TRANSCRIPTION_ORIENTATION,
-                worker_version_id=get_object_id(transcription.worker_version),
-                worker_run_id=get_object_id(transcription.worker_run),
-            )
-            for transcription in self.list_transcriptions(
-                ArkindexElement(id=element.id)
-            )
-        ]
+        transcriptions: list[CachedTranscription] = self.get_transcriptions(element)
         if transcriptions:
             logger.info(f"Inserting {len(transcriptions)} transcription(s)")
             with cache_database.atomic():
@@ -131,29 +106,10 @@
         entities: List[CachedEntity] = []
         transcription_entities: List[CachedTranscriptionEntity] = []
         for transcription in transcriptions:
-            for transcription_entity in self.list_transcription_entities(
-                ArkindexTranscription(id=transcription.id)
-            ):
-                entity = CachedEntity(
-                    id=transcription_entity.entity.id,
-                    type=transcription_entity.entity.type.name,
-                    name=transcription_entity.entity.name,
-                    validated=transcription_entity.entity.validated,
-                    metas=transcription_entity.entity.metas,
-                    worker_run_id=get_object_id(transcription_entity.entity.worker_run),
-                )
-                entities.append(entity)
-                transcription_entities.append(
-                    CachedTranscriptionEntity(
-                        id=transcription_entity.id,
-                        transcription=transcription,
-                        entity=entity,
-                        offset=transcription_entity.offset,
-                        length=transcription_entity.length,
-                        confidence=transcription_entity.confidence,
-                        worker_run_id=get_object_id(transcription_entity.worker_run),
-                    )
-                )
+            parsed_entities = self.get_transcription_entities(transcription)
+            entities.extend(parsed_entities[0])
+            transcription_entities.extend(parsed_entities[1])
         if entities:
             # First insert entities since they are foreign keys on transcription entities
             logger.info(f"Inserting {len(entities)} entities")
@@ -194,7 +150,7 @@
         :param element: Element to insert.
         :param parent_id: ID of the parent to use when creating the CachedElement. Do not specify for top-level elements.
         """
-        logger.info(f"Processing element ({element.id})")
+        logger.info(f"Processing element ({element})")
         if isinstance(element, Element):
             image = element.image
@@ -271,15 +227,14 @@
             f"Filling the cache with information from elements in the split {split_name}"
         )
         for idx, element in enumerate(elements, start=1):
-            logger.info(f"Processing `{split_name}` element ({idx})")
+            logger.info(f"Processing `{split_name}` element (n°{idx})")
             # Insert page
             self.insert_element(element, split_name=split_name)
             # List children
             children = self.list_element_children(element)
             for child_idx, child in enumerate(children, start=1):
-                logger.info(f"Processing child ({child_idx})")
+                logger.info(f"Processing {child} ({child_idx})")
                 # Insert child
                 self.insert_element(child, parent_id=element.id)
...
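A reading note, not part of the commit: the shape of the refactor is a template-method split. `Extractor` keeps the insert-and-cache logic and delegates listing/parsing to backend-specific `get_*` hooks, which the API and SQL subclasses below each override. A minimal sketch of the pattern (class and method names taken from this diff; bodies elided):

    class Extractor(DatasetWorker):
        def insert_classifications(self, element: CachedElement) -> None:
            # Backend-agnostic: fetch rows via the hook, then write them to the cache.
            classifications = self.get_classifications(element)
            if classifications:
                with cache_database.atomic():
                    ...  # bulk-insert the rows into the SQLite cache

    class DatasetExtractorFromAPI(Extractor, ...):
        def get_classifications(self, element: CachedElement):
            ...  # build CachedClassification rows from Arkindex API payloads

    class DatasetExtractorFromSQL(Extractor):
        def get_classifications(self, element: CachedElement):
            ...  # build CachedClassification rows from arkindex_export queries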
 # -*- coding: utf-8 -*-
-from uuid import UUID
+from collections.abc import Iterator
-from arkindex_worker.models import MagicDict
+from arkindex_worker.cache import (
+    CachedClassification,
+    CachedElement,
+    CachedEntity,
+    CachedTranscription,
+    CachedTranscriptionEntity,
+)
+from arkindex_worker.models import Element, Transcription
 from arkindex_worker.worker.classification import ClassificationMixin
 from arkindex_worker.worker.element import ElementMixin
 from arkindex_worker.worker.entity import EntityMixin
 from arkindex_worker.worker.metadata import MetaDataMixin
 from arkindex_worker.worker.transcription import TranscriptionMixin
-from worker_generic_training_dataset import Extractor
+from worker_generic_training_dataset import DEFAULT_TRANSCRIPTION_ORIENTATION, Extractor
+from worker_generic_training_dataset.utils import get_id_or_null


 class DatasetExtractorFromAPI(
@@ -18,13 +26,69 @@ class DatasetExtractorFromAPI(
     TranscriptionMixin,
     MetaDataMixin,
 ):
-    def list_classifications(self, element_id: UUID):
-        return map(
-            MagicDict,
-            self.api_client.request("RetrieveElement", id=str(element_id))[
-                "classifications"
-            ],
-        )
+    def list_element_children(self, *args, **kwargs) -> Iterator[Element]:
+        return map(Element, super().list_element_children(*args, **kwargs))
+
+    def get_classifications(self, element: CachedElement):
+        return [
+            CachedClassification(
+                id=classification["id"],
+                element=element,
+                class_name=classification["ml_class"]["name"],
+                confidence=classification["confidence"],
+                state=classification["state"],
+                worker_run_id=get_id_or_null(classification["worker_run"]),
+            )
+            for classification in self.api_client.request(
+                "RetrieveElement", id=str(element.id)
+            )["classifications"]
+        ]
+
+    def get_transcriptions(self, element: CachedElement):
+        return [
+            CachedTranscription(
+                id=transcription["id"],
+                element=element,
+                text=transcription["text"],
+                confidence=transcription["confidence"],
+                orientation=DEFAULT_TRANSCRIPTION_ORIENTATION,
+                worker_version_id=transcription["worker_version_id"],
+                worker_run_id=get_id_or_null(transcription["worker_run"]),
+            )
+            for transcription in self.list_transcriptions(element)
+        ]
+
+    def get_transcription_entities(
+        self, transcription: CachedTranscription
+    ) -> tuple[list[CachedEntity], list[CachedTranscriptionEntity]]:
+        entities: list[CachedEntity] = []
+        transcription_entities: list[CachedTranscriptionEntity] = []
+        for transcription_entity in self.list_transcription_entities(
+            Transcription(id=transcription.id)
+        ):
+            ark_entity = transcription_entity["entity"]
+            entity = CachedEntity(
+                id=ark_entity["id"],
+                type=ark_entity["type"]["name"],
+                name=ark_entity["name"],
+                validated=ark_entity["validated"],
+                metas=ark_entity["metas"],
+                worker_run_id=get_id_or_null(ark_entity["worker_run"]),
+            )
+            entities.append(entity)
+            transcription_entities.append(
+                CachedTranscriptionEntity(
+                    id=transcription_entity["id"],
+                    transcription=transcription,
+                    entity=entity,
+                    offset=transcription_entity["offset"],
+                    length=transcription_entity["length"],
+                    confidence=transcription_entity["confidence"],
+                    worker_run_id=get_id_or_null(transcription_entity["worker_run"]),
+                )
+            )
+        return entities, transcription_entities


 def main():
...
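The (entities, transcription_entities) tuple returned above feeds the insert step truncated in the first file's hunks. A plausible sketch of that step, assuming peewee's bulk_create is available on the cached models and treating the batch size as an illustrative choice, not something this commit specifies:

    def insert_entities(self, transcriptions: list[CachedTranscription]) -> None:
        entities: list[CachedEntity] = []
        transcription_entities: list[CachedTranscriptionEntity] = []
        for transcription in transcriptions:
            parsed_entities = self.get_transcription_entities(transcription)
            entities.extend(parsed_entities[0])
            transcription_entities.extend(parsed_entities[1])
        if entities:
            # One transaction so a failure leaves the cache database clean.
            with cache_database.atomic():
                # Entities go first: each CachedTranscriptionEntity row holds a
                # foreign key to a CachedEntity row, so the targets must exist.
                CachedEntity.bulk_create(entities, batch_size=50)
                CachedTranscriptionEntity.bulk_create(transcription_entities, batch_size=50)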
@@ -3,7 +3,6 @@
 import logging
 from operator import itemgetter
 from tempfile import _TemporaryFileWrapper
 from typing import List
-from uuid import UUID
 from apistar.exceptions import ErrorResponse
 from arkindex_export import (
@@ -20,6 +19,13 @@ from arkindex_export import (
     open_database,
 )
 from arkindex_export.queries import list_children
+from arkindex_worker.cache import (
+    CachedClassification,
+    CachedElement,
+    CachedEntity,
+    CachedTranscription,
+    CachedTranscriptionEntity,
+)
 from arkindex_worker.models import Element as ArkindexElement
 from arkindex_worker.models import Set
 from arkindex_worker.models import Transcription as ArkindexTranscription
@@ -97,18 +103,6 @@ class DatasetExtractorFromSQL(Extractor):
             )
         )
-
-    def list_classifications(self, element_id: UUID):
-        return (
-            Classification.select()
-            .where(Classification.element == element_id)
-            .iterator()
-        )
-
-    def list_transcriptions(self, element: ArkindexElement, **kwargs):
-        return (
-            Transcription.select().where(Transcription.element == element.id).iterator()
-        )
     def list_transcription_entities(
         self, transcription: ArkindexTranscription, **kwargs
     ):
@@ -122,6 +116,70 @@ class DatasetExtractorFromSQL(Extractor):
     def list_element_children(self, element: ArkindexElement, **kwargs):
         return list_children(element.id).iterator()
+
+    def get_classifications(self, element: CachedElement) -> list[CachedClassification]:
+        return [
+            CachedClassification(
+                id=classification.id,
+                element=element,
+                class_name=classification.class_name,
+                confidence=classification.confidence,
+                state=classification.state,
+                worker_run_id=get_object_id(classification.worker_run),
+            )
+            for classification in (
+                Classification.select()
+                .where(Classification.element == element.id)
+                .iterator()
+            )
+        ]
+
+    def get_transcriptions(self, element: CachedElement) -> list[CachedTranscription]:
+        return [
+            CachedTranscription(
+                id=transcription.id,
+                element=element,
+                text=transcription.text,
+                confidence=transcription.confidence,
+                orientation=DEFAULT_TRANSCRIPTION_ORIENTATION,
+                worker_version_id=get_object_id(transcription.worker_version),
+                worker_run_id=get_object_id(transcription.worker_run),
+            )
+            for transcription in Transcription.select()
+            .where(Transcription.element == element.id)
+            .iterator()
+        ]
+
+    def get_transcription_entities(
+        self, transcription: CachedTranscription
+    ) -> tuple[list[CachedEntity], list[CachedTranscriptionEntity]]:
+        entities: List[CachedEntity] = []
+        transcription_entities: List[CachedTranscriptionEntity] = []
+        for transcription_entity in self.list_transcription_entities(
+            ArkindexTranscription(id=transcription.id)
+        ):
+            entity = CachedEntity(
+                id=transcription_entity.entity.id,
+                type=transcription_entity.entity.type.name,
+                name=transcription_entity.entity.name,
+                validated=transcription_entity.entity.validated,
+                metas=transcription_entity.entity.metas,
+                worker_run_id=get_object_id(transcription_entity.entity.worker_run),
+            )
+            entities.append(entity)
+            transcription_entities.append(
+                CachedTranscriptionEntity(
+                    id=transcription_entity.id,
+                    transcription=transcription,
+                    entity=entity,
+                    offset=transcription_entity.offset,
+                    length=transcription_entity.length,
+                    confidence=transcription_entity.confidence,
+                    worker_run_id=get_object_id(transcription_entity.worker_run),
+                )
+            )
+        return entities, transcription_entities


 def main():
     DatasetExtractorFromSQL(
...
@@ -20,3 +20,7 @@ def build_image_url(image, polygon: str | list[list[int]]) -> str:
     height: int
     x, y, width, height = bbox
     return urljoin(image.url + "/", f"{x},{y},{width},{height}/full/0/default.jpg")
+
+
+def get_id_or_null(value: dict | None) -> str | None:
+    return value["id"] if value else None
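A quick illustration of the new helper: the API serializes worker_run either as a nested object or as null, while the cache schema stores only the ID. The payloads below are made up for the example, not taken from a real corpus:

    from worker_generic_training_dataset.utils import get_id_or_null

    # A classification produced by an ML worker: worker_run is a nested object.
    classification = {"worker_run": {"id": "11111111-1111-1111-1111-111111111111"}}
    # A manual annotation: worker_run is null in the API payload.
    manual_classification = {"worker_run": None}

    assert get_id_or_null(classification["worker_run"]) == "11111111-1111-1111-1111-111111111111"
    assert get_id_or_null(manual_classification["worker_run"]) is None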