
New DatasetExtractor using a DatasetWorker

Merged Eva Bardou requested to merge dataset-worker into main
All threads resolved!
# -*- coding: utf-8 -*-
import logging
import sys
import tempfile
from argparse import Namespace
from itertools import groupby
from operator import itemgetter
from pathlib import Path
from tempfile import _TemporaryFileWrapper
from typing import Iterator, List, Optional, Tuple
from uuid import UUID

from apistar.exceptions import ErrorResponse
from arkindex_export import Element, open_database
from arkindex_export.queries import list_children
from arkindex_worker.cache import (
    CachedClassification,
    CachedElement,
    CachedEntity,
    CachedImage,
    CachedTranscription,
    CachedTranscriptionEntity,
    create_tables,
    create_version_table,
)
from arkindex_worker.cache import db as cache_database
from arkindex_worker.cache import init_cache_db
from arkindex_worker.image import download_image
from arkindex_worker.models import Dataset
from arkindex_worker.utils import create_tar_zst_archive
from arkindex_worker.worker.base import BaseWorker
from arkindex_worker.worker.dataset import DatasetMixin, DatasetState

from worker_generic_training_dataset.db import (
    list_classifications,
    list_transcription_entities,
    list_transcriptions,
)
from worker_generic_training_dataset.utils import build_image_url
from worker_generic_training_dataset.worker import (
    BULK_BATCH_SIZE,
    DEFAULT_TRANSCRIPTION_ORIENTATION,
)

logger: logging.Logger = logging.getLogger(__name__)


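# Generic base class for workers that process whole Arkindex datasets: it lists the
# datasets to handle (from the Arkindex process, or from the --dataset IDs given on
# the command line when running locally) and, when `generator` is True, moves each
# dataset through the Open -> Building -> Complete/Error states around
# `process_dataset`.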
class DatasetWorker(BaseWorker, DatasetMixin):
    def __init__(
        self,
        description: str = "Arkindex Elements Worker",
        support_cache: bool = False,
        generator: bool = False,
    ):
        super().__init__(description, support_cache)

        self.parser.add_argument(
            "--dataset",
            type=UUID,
            nargs="+",
            help="One or more Arkindex dataset IDs",
        )

        self.generator = generator

    def list_dataset_elements_per_set(
        self, dataset: Dataset
    ) -> Iterator[Tuple[str, Element]]:
        """
        Calls `list_dataset_elements` but returns the results grouped by set.
        """

        def format_element(element):
            return Element.get(Element.id == element[1].id)

        def format_set(set):
            return (set[0], list(map(format_element, list(set[1]))))

        return list(
            map(
                format_set,
                groupby(
                    sorted(self.list_dataset_elements(dataset), key=itemgetter(0)),
                    key=itemgetter(0),
                ),
            )
        )

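    # Note: the grouping above yields pairs like ("train", [<Element>, ...]), one
    # entry per set defined on the dataset, with each element re-fetched as an
    # arkindex_export Element row ("train" is only an example set name).
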
    def process_dataset(self, dataset: Dataset):
        """
        Override this method to implement your worker and process a single Arkindex dataset at once.

        :param dataset: The dataset to process.
        """

    def list_datasets(self) -> List[Dataset] | List[str]:
        """
        Calls `list_process_datasets` if not is_read_only,
        else simply returns the list of IDs provided via CLI.
        """
        if self.is_read_only:
            return list(map(str, self.args.dataset))

        return self.list_process_datasets()

    def run(self):
        self.configure()

        datasets: List[Dataset] | List[str] = self.list_datasets()
        if not datasets:
            logger.warning("No datasets to process, stopping.")
            sys.exit(1)

        # Process every dataset
        count = len(datasets)
        failed = 0
        for i, item in enumerate(datasets, start=1):
            dataset = None
            try:
                if not self.is_read_only:
                    # Just use the result of list_datasets as the dataset
                    dataset = item
                else:
                    # Load the dataset using the Arkindex API
                    dataset = Dataset(**self.request("RetrieveDataset", id=item))

                if self.generator:
                    assert (
                        dataset.state == DatasetState.Open.value
                    ), "When generating a new dataset, its state should be Open"
                else:
                    assert (
                        dataset.state == DatasetState.Complete.value
                    ), "When processing an existing dataset, its state should be Complete"

                if self.generator:
                    # Update the dataset state to Building
                    logger.info(f"Building {dataset} ({i}/{count})")
                    self.update_dataset_state(dataset, DatasetState.Building)

                # Process the dataset
                self.process_dataset(dataset)

                if self.generator:
                    # Update the dataset state to Complete
                    logger.info(f"Completed {dataset} ({i}/{count})")
                    self.update_dataset_state(dataset, DatasetState.Complete)
            except Exception as e:
                # Handle errors occurring while retrieving, processing or patching the state for this dataset.
                failed += 1

                # Handle the case where we failed to retrieve the dataset
                dataset_id = dataset.id if dataset else item

                if isinstance(e, ErrorResponse):
                    message = f"An API error occurred while processing dataset {dataset_id}: {e.title} - {e.content}"
                else:
                    message = (
                        f"Failed running worker on dataset {dataset_id}: {repr(e)}"
                    )

                logger.warning(
                    message,
                    exc_info=e if self.args.verbose else None,
                )
                if dataset and self.generator:
                    # Try to update the state to Error regardless of the response
                    try:
                        self.update_dataset_state(dataset, DatasetState.Error)
                    except Exception:
                        pass

        if failed:
            logger.error(
                "Ran on {} dataset(s): {} completed, {} failed".format(
                    count, count - failed, failed
                )
            )

        if failed >= count:  # Everything failed!
            sys.exit(1)


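# Worker filling a base-worker compatible SQLite cache from the latest corpus export:
# for every element of every dataset set, it stores the element together with its
# classifications, transcriptions and entities, downloads the corresponding images,
# and finally ships them as a compressed task artifact.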
class DatasetExtractor(DatasetWorker):
    def configure(self) -> None:
        self.args: Namespace = self.parser.parse_args()
        if self.is_read_only:
            super().configure_for_developers()
        else:
            super().configure()

        if self.user_configuration:
            logger.info("Overriding with user_configuration")
            self.config.update(self.user_configuration)

        # Download corpus
        self.download_latest_export()

        # Initialize db that will be written
        self.configure_cache()

        # CachedImages downloaded and created in DB
        self.cached_images = dict()

        # Where to save the downloaded images
        self.image_folder = Path(tempfile.mkdtemp(suffix="-arkindex-data"))
        logger.info(f"Images will be saved at `{self.image_folder}`.")

    def configure_cache(self) -> None:
        """
        Create an SQLite database compatible with base-worker cache and initialize it.
        """
        self.use_cache = True
        self.cache_path: Path = self.args.database or self.work_dir / "db.sqlite"
        # Remove previous execution result if present
        self.cache_path.unlink(missing_ok=True)

        init_cache_db(self.cache_path)

        create_version_table()

        create_tables()

    def download_latest_export(self) -> None:
        """
        Download the latest export of the current corpus.
        Export must be in `"done"` state.
        """
        try:
            exports = list(
                self.api_client.paginate(
                    "ListExports",
                    id=self.corpus_id,
                )
            )
        except ErrorResponse as e:
            logger.error(
                f"Could not list exports of corpus ({self.corpus_id}): {str(e.content)}"
            )
            raise e

        # Find the latest that is in "done" state
        exports: List[dict] = sorted(
            list(filter(lambda exp: exp["state"] == "done", exports)),
            key=itemgetter("updated"),
            reverse=True,
        )
        assert (
            len(exports) > 0
        ), f"No available exports found for the corpus {self.corpus_id}."

        # Download latest export
        try:
            export_id: str = exports[0]["id"]
            logger.info(f"Downloading export ({export_id})...")
            self.export: _TemporaryFileWrapper = self.api_client.request(
                "DownloadExport",
                id=export_id,
            )
            logger.info(f"Downloaded export ({export_id}) @ `{self.export.name}`")
            open_database(self.export.name)
        except ErrorResponse as e:
            logger.error(
                f"Could not download export ({export_id}) of corpus ({self.corpus_id}): {str(e.content)}"
            )
            raise e

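    # Note: once `open_database` has been called on the downloaded export, the
    # arkindex_export models (Element, ...) and the helpers from
    # worker_generic_training_dataset.db are backed by that SQLite file rather than
    # the Arkindex API.
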
    def insert_classifications(self, element: CachedElement) -> None:
        """
        Insert the classifications of the given element in the cache database.
        """
        logger.info("Listing classifications")
        classifications: list[CachedClassification] = [
            CachedClassification(
                id=classification.id,
                element=element,
                class_name=classification.class_name,
                confidence=classification.confidence,
                state=classification.state,
                worker_run_id=classification.worker_run,
            )
            for classification in list_classifications(element.id)
        ]
        if classifications:
            logger.info(f"Inserting {len(classifications)} classification(s)")
            with cache_database.atomic():
                CachedClassification.bulk_create(
                    model_list=classifications,
                    batch_size=BULK_BATCH_SIZE,
                )

    def insert_transcriptions(
        self, element: CachedElement
    ) -> List[CachedTranscription]:
        """
        Insert the transcriptions of the given element in the cache database
        and return them.
        """
        logger.info("Listing transcriptions")
        transcriptions: list[CachedTranscription] = [
            CachedTranscription(
                id=transcription.id,
                element=element,
                text=transcription.text,
                confidence=transcription.confidence,
                orientation=DEFAULT_TRANSCRIPTION_ORIENTATION,
                worker_version_id=transcription.worker_version,
                worker_run_id=transcription.worker_run,
            )
            for transcription in list_transcriptions(element.id)
        ]
        if transcriptions:
            logger.info(f"Inserting {len(transcriptions)} transcription(s)")
            with cache_database.atomic():
                CachedTranscription.bulk_create(
                    model_list=transcriptions,
                    batch_size=BULK_BATCH_SIZE,
                )

        return transcriptions

    def insert_entities(self, transcriptions: List[CachedTranscription]) -> None:
        """
        Insert the entities of the given transcriptions in the cache database,
        as both CachedEntity and CachedTranscriptionEntity rows.
        """
        logger.info("Listing entities")
        entities: List[CachedEntity] = []
        transcription_entities: List[CachedTranscriptionEntity] = []
        for transcription in transcriptions:
            for transcription_entity in list_transcription_entities(transcription.id):
                entity = CachedEntity(
                    id=transcription_entity.entity.id,
                    type=transcription_entity.entity.type.name,
                    name=transcription_entity.entity.name,
                    validated=transcription_entity.entity.validated,
                    metas=transcription_entity.entity.metas,
                    worker_run_id=transcription_entity.entity.worker_run,
                )
                entities.append(entity)
                transcription_entities.append(
                    CachedTranscriptionEntity(
                        id=transcription_entity.id,
                        transcription=transcription,
                        entity=entity,
                        offset=transcription_entity.offset,
                        length=transcription_entity.length,
                        confidence=transcription_entity.confidence,
                        worker_run_id=transcription_entity.worker_run,
                    )
                )

        if entities:
            # First insert entities since they are foreign keys on transcription entities
            logger.info(f"Inserting {len(entities)} entities")
            with cache_database.atomic():
                CachedEntity.bulk_create(
                    model_list=entities,
                    batch_size=BULK_BATCH_SIZE,
                )

        if transcription_entities:
            # Insert transcription entities
            logger.info(
                f"Inserting {len(transcription_entities)} transcription entities"
            )
            with cache_database.atomic():
                CachedTranscriptionEntity.bulk_create(
                    model_list=transcription_entities,
                    batch_size=BULK_BATCH_SIZE,
                )

    def insert_element(
        self, element: Element, parent_id: Optional[UUID] = None
    ) -> None:
        """
        Insert the given element in the cache database.
        Its image will also be saved to disk, if it wasn't already.

        The insertion of an element includes:
        - its classifications
        - its transcriptions
        - its transcriptions' entities (both Entity and TranscriptionEntity)

        :param element: Element to insert.
        :param parent_id: ID of the parent to use when creating the CachedElement. Do not specify for top-level elements.
        """
        logger.info(f"Processing element ({element.id})")
        if element.image and element.image.id not in self.cached_images:
            # Download image
            logger.info("Downloading image")
            download_image(url=build_image_url(element)).save(
                self.image_folder / f"{element.image.id}.jpg"
            )
            # Insert image
            logger.info("Inserting image")
            # Store the image in case some other elements use it as well
            with cache_database.atomic():
                self.cached_images[element.image.id] = CachedImage.create(
                    id=element.image.id,
                    width=element.image.width,
                    height=element.image.height,
                    url=element.image.url,
                )

        # Insert element
        logger.info("Inserting element")
        with cache_database.atomic():
            cached_element: CachedElement = CachedElement.create(
                id=element.id,
                parent_id=parent_id,
                type=element.type,
                image=self.cached_images[element.image.id] if element.image else None,
                polygon=element.polygon,
                rotation_angle=element.rotation_angle,
                mirrored=element.mirrored,
                worker_version_id=element.worker_version,
                worker_run_id=element.worker_run,
                confidence=element.confidence,
            )

        # Insert classifications
        self.insert_classifications(cached_element)

        # Insert transcriptions
        transcriptions: List[CachedTranscription] = self.insert_transcriptions(
            cached_element
        )

        # Insert entities
        self.insert_entities(transcriptions)

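    # Each set is processed element by element: the (page) element itself is inserted
    # first, then all of its children found in the export via `list_children`.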
    def process_set(self, set_name: str, elements: List[Element]) -> None:
        logger.info(
            f"Filling the cache with information from elements in the set {set_name}"
        )

        # First list all pages
        nb_elements: int = len(elements)
        for idx, element in enumerate(elements, start=1):
            logger.info(f"Processing `{set_name}` element ({idx}/{nb_elements})")

            # Insert page
            self.insert_element(element)

            # List children
            children = list_children(element.id)
            nb_children: int = children.count()
            for child_idx, child in enumerate(children, start=1):
                logger.info(f"Processing child ({child_idx}/{nb_children})")

                # Insert child
                self.insert_element(child, parent_id=element.id)

    def process_dataset(self, dataset: Dataset):
        # Iterate over the given sets
        for set_name, elements in self.list_dataset_elements_per_set(dataset):
            self.process_set(set_name, elements)

        # TAR + ZSTD the image folder and store it as a task artifact
        zstd_archive_path: Path = self.work_dir / f"{dataset.id}.zstd"
        logger.info(f"Compressing the images to {zstd_archive_path}")
        create_tar_zst_archive(source=self.image_folder, destination=zstd_archive_path)


def main():
    DatasetExtractor(
        description="Fill base-worker cache with information about dataset and extract images",
        generator=True,
    ).run()


if __name__ == "__main__":
    main()
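
# Hypothetical local run for development (read-only mode); the file name and the
# dataset UUID below are placeholders. `--dataset` is the flag defined in
# DatasetWorker.__init__ above, the remaining CLI flags come from the base-worker:
#   python <this_file>.py --dataset aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee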