Implement worker

Merged: Yoann Schneider requested to merge implem into main
# -*- coding: utf-8 -*-
import logging
import operator
import tempfile
from argparse import Namespace
from pathlib import Path
from tempfile import _TemporaryFileWrapper
from typing import List, Optional
from uuid import UUID

from apistar.exceptions import ErrorResponse
from arkindex_export import Element, Image, open_database
from arkindex_export.queries import list_children
from arkindex_worker.cache import (
    CachedClassification,
    CachedElement,
    CachedEntity,
    CachedImage,
    CachedTranscription,
    CachedTranscriptionEntity,
    create_tables,
    create_version_table,
)
from arkindex_worker.cache import db as cache_database
from arkindex_worker.cache import init_cache_db
from arkindex_worker.image import download_image
from arkindex_worker.utils import create_tar_zst_archive
from arkindex_worker.worker.base import BaseWorker

from worker_generic_training_dataset.db import (
    list_classifications,
    list_transcription_entities,
    list_transcriptions,
    retrieve_element,
)
from worker_generic_training_dataset.utils import build_image_url

logger: logging.Logger = logging.getLogger(__name__)

BULK_BATCH_SIZE = 50
DEFAULT_TRANSCRIPTION_ORIENTATION = "horizontal-lr"


class DatasetExtractor(BaseWorker):
    """
    Fill the base-worker cache with the elements, classifications,
    transcriptions and entities found under the train/validation/test folders
    of a process, and download the corresponding images.
    """

    def configure(self) -> None:
        self.args: Namespace = self.parser.parse_args()

        if self.is_read_only:
            super().configure_for_developers()
        else:
            super().configure()

        if self.user_configuration:
            logger.info("Overriding with user_configuration")
            self.config.update(self.user_configuration)

        # Read process information
        self.read_training_related_information()

        # Download corpus
        self.download_latest_export()

        # Initialize the database that will be written to
        self.configure_cache()

        # CachedImage instances already downloaded and created in the cache DB
        self.cached_images = dict()

        # Where to save the downloaded images
        self.image_folder = Path(tempfile.mkdtemp(suffix="-arkindex-data"))
        logger.info(f"Images will be saved at `{self.image_folder}`.")

    def read_training_related_information(self) -> None:
        """
        Read from process information:
        - train_folder_id
        - validation_folder_id
        - test_folder_id (optional)
        """
        logger.info("Retrieving information from process_information")

        train_folder_id = self.process_information.get("train_folder_id")
        assert train_folder_id, "A training folder id is necessary to use this worker"
        self.training_folder_id = UUID(train_folder_id)

        val_folder_id = self.process_information.get("validation_folder_id")
        assert val_folder_id, "A validation folder id is necessary to use this worker"
        self.validation_folder_id = UUID(val_folder_id)

        test_folder_id = self.process_information.get("test_folder_id")
        self.testing_folder_id: UUID | None = (
            UUID(test_folder_id) if test_folder_id else None
        )

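    # Illustrative sketch (not part of the MR) of the process information
    # payload that `read_training_related_information` expects; the UUID
    # values below are purely hypothetical.
    #
    #   {
    #       "train_folder_id": "12341234-1234-1234-1234-123412341234",
    #       "validation_folder_id": "43214321-4321-4321-4321-432143214321",
    #       "test_folder_id": None,  # optional, may be missing or null
    #   }
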
    def configure_cache(self) -> None:
        """
        Create an SQLite database compatible with the base-worker cache and initialize it.
        """
        self.use_cache = True
        self.cache_path: Path = self.args.database or self.work_dir / "db.sqlite"

        # Remove previous execution result if present
        self.cache_path.unlink(missing_ok=True)

        init_cache_db(self.cache_path)
        create_version_table()
        create_tables()

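    # Note (added for context, not in the original file): `create_tables` sets
    # up the base-worker cache schema backing the Cached* models imported
    # above, and `create_version_table` records the cache format version,
    # presumably so that downstream consumers can check compatibility.
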
    def download_latest_export(self) -> None:
        """
        Download the latest export of the current corpus.
        Export must be in `"done"` state.
        """
        try:
            exports = list(
                self.api_client.paginate(
                    "ListExports",
                    id=self.corpus_id,
                )
            )
        except ErrorResponse as e:
            logger.error(
                f"Could not list exports of corpus ({self.corpus_id}): {str(e.content)}"
            )
            raise e

        # Find the latest that is in "done" state
        exports: List[dict] = sorted(
            list(filter(lambda exp: exp["state"] == "done", exports)),
            key=operator.itemgetter("updated"),
            reverse=True,
        )
        assert (
            len(exports) > 0
        ), f"No available exports found for the corpus {self.corpus_id}."

        # Download latest export
        try:
            export_id: str = exports[0]["id"]
            logger.info(f"Downloading export ({export_id})...")
            self.export: _TemporaryFileWrapper = self.api_client.request(
                "DownloadExport",
                id=export_id,
            )
            logger.info(f"Downloaded export ({export_id}) @ `{self.export.name}`")
            open_database(self.export.name)
        except ErrorResponse as e:
            logger.error(
                f"Could not download export ({export_id}) of corpus ({self.corpus_id}): {str(e.content)}"
            )
            raise e

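    # For readers unfamiliar with the export endpoints: the fields accessed
    # above ("state", "updated", "id") are the only ones this worker relies
    # on, and only exports in the "done" state are considered downloadable.
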
    def insert_classifications(self, element: CachedElement) -> None:
        logger.info("Listing classifications")
        classifications: list[CachedClassification] = [
            CachedClassification(
                id=classification.id,
                element=element,
                class_name=classification.class_name,
                confidence=classification.confidence,
                state=classification.state,
            )
            for classification in list_classifications(element.id)
        ]
        if classifications:
            logger.info(f"Inserting {len(classifications)} classification(s)")
            with cache_database.atomic():
                CachedClassification.bulk_create(
                    model_list=classifications,
                    batch_size=BULK_BATCH_SIZE,
                )

    def insert_transcriptions(
        self, element: CachedElement
    ) -> List[CachedTranscription]:
        logger.info("Listing transcriptions")
        transcriptions: list[CachedTranscription] = [
            CachedTranscription(
                id=transcription.id,
                element=element,
                text=transcription.text,
                # Dodge not-null constraint for now
                confidence=transcription.confidence or 1.0,
                orientation=DEFAULT_TRANSCRIPTION_ORIENTATION,
                worker_version_id=transcription.worker_version.id
                if transcription.worker_version
                else None,
            )
            for transcription in list_transcriptions(element.id)
        ]
        if transcriptions:
            logger.info(f"Inserting {len(transcriptions)} transcription(s)")
            with cache_database.atomic():
                CachedTranscription.bulk_create(
                    model_list=transcriptions,
                    batch_size=BULK_BATCH_SIZE,
                )
        return transcriptions

    def insert_entities(self, transcriptions: List[CachedTranscription]) -> None:
        logger.info("Listing entities")
        entities: List[CachedEntity] = []
        transcription_entities: List[CachedTranscriptionEntity] = []
        for transcription in transcriptions:
            for transcription_entity in list_transcription_entities(transcription.id):
                entity = CachedEntity(
                    id=transcription_entity.entity.id,
                    type=transcription_entity.entity.type.name,
                    name=transcription_entity.entity.name,
                    validated=transcription_entity.entity.validated,
                    metas=transcription_entity.entity.metas,
                )
                entities.append(entity)
                transcription_entities.append(
                    CachedTranscriptionEntity(
                        id=transcription_entity.id,
                        transcription=transcription,
                        entity=entity,
                        offset=transcription_entity.offset,
                        length=transcription_entity.length,
                        confidence=transcription_entity.confidence,
                    )
                )

        if entities:
            # First insert entities since they are foreign keys on transcription entities
            logger.info(f"Inserting {len(entities)} entities")
            with cache_database.atomic():
                CachedEntity.bulk_create(
                    model_list=entities,
                    batch_size=BULK_BATCH_SIZE,
                )

        if transcription_entities:
            # Insert transcription entities
            logger.info(
                f"Inserting {len(transcription_entities)} transcription entities"
            )
            with cache_database.atomic():
                CachedTranscriptionEntity.bulk_create(
                    model_list=transcription_entities,
                    batch_size=BULK_BATCH_SIZE,
                )

    def insert_element(
        self, element: Element, parent_id: Optional[UUID] = None
    ) -> None:
        """
        Insert the given element in the cache database.
        Its image will also be saved to disk, if it wasn't already.

        The insertion of an element includes:
        - its classifications
        - its transcriptions
        - its transcriptions' entities (both Entity and TranscriptionEntity)

        :param element: Element to insert.
        :param parent_id: ID of the parent to use when creating the CachedElement. Do not specify for top-level elements.
        """
        logger.info(f"Processing element ({element.id})")
        if element.image and element.image.id not in self.cached_images:
            # Download image
            logger.info("Downloading image")
            download_image(url=build_image_url(element)).save(
                self.image_folder / f"{element.image.id}.jpg"
            )

            # Insert image
            logger.info("Inserting image")
            # Store the image in case other elements use it as well
            with cache_database.atomic():
                self.cached_images[element.image.id] = CachedImage.create(
                    id=element.image.id,
                    width=element.image.width,
                    height=element.image.height,
                    url=element.image.url,
                )

        # Insert element
        logger.info("Inserting element")
        with cache_database.atomic():
            cached_element: CachedElement = CachedElement.create(
                id=element.id,
                parent_id=parent_id,
                type=element.type,
                image=self.cached_images[element.image.id] if element.image else None,
                polygon=element.polygon,
                rotation_angle=element.rotation_angle,
                mirrored=element.mirrored,
                worker_version_id=element.worker_version.id
                if element.worker_version
                else None,
                confidence=element.confidence,
            )

        # Insert classifications
        self.insert_classifications(cached_element)

        # Insert transcriptions
        transcriptions: List[CachedTranscription] = self.insert_transcriptions(
            cached_element
        )

        # Insert entities
        self.insert_entities(transcriptions)

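    # Ordering note (added for clarity): transcriptions are inserted before
    # entities because `insert_entities` needs the CachedTranscription rows
    # returned by `insert_transcriptions` to build the TranscriptionEntity
    # links.
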
    def process_split(self, split_name: str, split_id: UUID) -> None:
        """
        Insert all elements under the given parent folder (all queries are recursive).
        - `page` elements are linked to this folder (via the parent_id foreign key)
        - `page` element children are linked to their `page` parent (via the parent_id foreign key)
        """
        logger.info(
            f"Filling the Base-Worker cache with information from children under element ({split_id})"
        )
        # Fill cache
        # Retrieve the parent element and insert it
        parent: Element = retrieve_element(split_id)
        self.insert_element(parent)

        # First list all pages
        pages = list_children(split_id).join(Image).where(Element.type == "page")
        nb_pages: int = pages.count()
        for idx, page in enumerate(pages, start=1):
            logger.info(f"Processing `{split_name}` page ({idx}/{nb_pages})")

            # Insert page
            self.insert_element(page, parent_id=split_id)

            # List children
            children = list_children(page.id)
            nb_children: int = children.count()
            for child_idx, child in enumerate(children, start=1):
                logger.info(f"Processing child ({child_idx}/{nb_children})")

                # Insert child
                self.insert_element(child, parent_id=page.id)

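    # For reference (not in the original file), the cache hierarchy built by
    # `process_split` looks like this; the child element types depend on the
    # corpus (text lines, regions, ...):
    #
    #   split folder
    #   └── page        (parent_id = split folder id)
    #       └── child   (parent_id = page id)
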
    def run(self):
        self.configure()

        # Iterate over the given splits
        for split_name, split_id in [
            ("Train", self.training_folder_id),
            ("Validation", self.validation_folder_id),
            ("Test", self.testing_folder_id),
        ]:
            if not split_id:
                continue
            self.process_split(split_name, split_id)

        # Tar + zstd the image folder and store it as a task artifact
        zstd_archive_path: Path = self.work_dir / "arkindex_data.zstd"
        logger.info(f"Compressing the images to {zstd_archive_path}")
        create_tar_zst_archive(source=self.image_folder, destination=zstd_archive_path)

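    # Note (added for context): the two outputs produced by `run` are the
    # SQLite cache created in `configure_cache` (db.sqlite under the work
    # directory by default) and the arkindex_data.zstd image archive above.
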
def main():
    DatasetExtractor(
        description="Fill base-worker cache with information about dataset and extract images",
        support_cache=True,
    ).run()

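# Entry-point sketch (an assumption, not shown in the diff): the project may
# instead expose `main` through a packaged console script.
if __name__ == "__main__":
    main()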