diff --git a/setup.py b/setup.py
index 81f4440affe641e4266ae08efa2af63c9b40c62c..4f4d657fc81d35812fb2f4bff9741a22d97fcb44 100755
--- a/setup.py
+++ b/setup.py
@@ -47,6 +47,11 @@ setup(
     author="Teklia",
     author_email="contact@teklia.com",
     install_requires=parse_requirements(),
-    entry_points={"console_scripts": [f"{COMMAND}={MODULE}.worker:main"]},
+    entry_points={
+        "console_scripts": [
+            f"{COMMAND}={MODULE}.worker:main",
+            "worker-generic-training-dataset-new=worker_generic_training_dataset.dataset_worker:main",
+        ]
+    },
     packages=find_packages(),
 )
diff --git a/worker_generic_training_dataset/dataset_worker.py b/worker_generic_training_dataset/dataset_worker.py
new file mode 100644
index 0000000000000000000000000000000000000000..a82c4649360a45447f5f8826a3c187cf4c6369e1
--- /dev/null
+++ b/worker_generic_training_dataset/dataset_worker.py
@@ -0,0 +1,458 @@
+# -*- coding: utf-8 -*-
+import logging
+import sys
+import tempfile
+from argparse import Namespace
+from itertools import groupby
+from operator import itemgetter
+from pathlib import Path
+from tempfile import _TemporaryFileWrapper
+from typing import List, Optional, Tuple
+from uuid import UUID
+
+from apistar.exceptions import ErrorResponse
+from arkindex_export import Element, open_database
+from arkindex_export.queries import list_children
+from arkindex_worker.cache import (
+    CachedClassification,
+    CachedElement,
+    CachedEntity,
+    CachedImage,
+    CachedTranscription,
+    CachedTranscriptionEntity,
+    create_tables,
+    create_version_table,
+)
+from arkindex_worker.cache import db as cache_database
+from arkindex_worker.cache import init_cache_db
+from arkindex_worker.image import download_image
+from arkindex_worker.models import Dataset
+from arkindex_worker.utils import create_tar_zst_archive
+from arkindex_worker.worker.base import BaseWorker
+from arkindex_worker.worker.dataset import DatasetMixin, DatasetState
+from worker_generic_training_dataset.db import (
+    list_classifications,
+    list_transcription_entities,
+    list_transcriptions,
+)
+from worker_generic_training_dataset.utils import build_image_url
+from worker_generic_training_dataset.worker import (
+    BULK_BATCH_SIZE,
+    DEFAULT_TRANSCRIPTION_ORIENTATION,
+)
+
+logger: logging.Logger = logging.getLogger(__name__)
+
+
+class DatasetWorker(BaseWorker, DatasetMixin):
+    def __init__(
+        self,
+        description: str = "Arkindex Elements Worker",
+        support_cache: bool = False,
+        generator: bool = False,
+    ):
+        super().__init__(description, support_cache)
+
+        self.parser.add_argument(
+            "--dataset",
+            type=UUID,
+            nargs="+",
+            help="One or more Arkindex dataset IDs",
+        )
+
+        self.generator = generator
+
+    def list_dataset_elements_per_set(
+        self, dataset: Dataset
+    ) -> List[Tuple[str, List[Element]]]:
+        """
+        Calls `list_dataset_elements` but returns the results grouped by set.
+        """
+
+        def format_element(element):
+            return Element.get(Element.id == element[1].id)
+
+        def format_set(set):
+            return (set[0], list(map(format_element, list(set[1]))))
+
+        return list(
+            map(
+                format_set,
+                groupby(
+                    sorted(self.list_dataset_elements(dataset), key=itemgetter(0)),
+                    key=itemgetter(0),
+                ),
+            )
+        )
+
+    def process_dataset(self, dataset: Dataset):
+        """
+        Override this method to implement your worker and process a single Arkindex dataset at once.
+
+        :param dataset: The dataset to process.
+        """
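+
+        # Illustrative sketch only: a concrete worker might override this hook
+        # roughly as follows (`handle_set` is a hypothetical helper, not part of
+        # this module):
+        #
+        #     class MyDatasetWorker(DatasetWorker):
+        #         def process_dataset(self, dataset: Dataset):
+        #             for set_name, elements in self.list_dataset_elements_per_set(dataset):
+        #                 handle_set(set_name, elements)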
+ """ + + def list_datasets(self) -> List[Dataset] | List[str]: + """ + Calls `list_process_datasets` if not is_read_only, + else simply give the list of IDs provided via CLI + """ + if self.is_read_only: + return list(map(str, self.args.dataset)) + + return self.list_process_datasets() + + def run(self): + self.configure() + + datasets: List[Dataset] | List[str] = self.list_datasets() + if not datasets: + logger.warning("No datasets to process, stopping.") + sys.exit(1) + + # Process every dataset + count = len(datasets) + failed = 0 + for i, item in enumerate(datasets, start=1): + dataset = None + try: + if not self.is_read_only: + # Just use the result of list_datasets as the dataset + dataset = item + else: + # Load dataset using the Arkindex API + dataset = Dataset(**self.request("RetrieveDataset", id=item)) + + if self.generator: + assert ( + dataset.state == DatasetState.Open.value + ), "When generating a new dataset, its state should be Open" + else: + assert ( + dataset.state == DatasetState.Complete.value + ), "When processing an existing dataset, its state should be Complete" + + if self.generator: + # Update the dataset state to Building + logger.info(f"Building {dataset} ({i}/{count})") + self.update_dataset_state(dataset, DatasetState.Building) + + # Process the dataset + self.process_dataset(dataset) + + if self.generator: + # Update the dataset state to Complete + logger.info(f"Completed {dataset} ({i}/{count})") + self.update_dataset_state(dataset, DatasetState.Complete) + except Exception as e: + # Handle errors occurring while retrieving, processing or patching the state for this dataset. + failed += 1 + + # Handle the case where we failed retrieving the dataset + dataset_id = dataset.id if dataset else item + + if isinstance(e, ErrorResponse): + message = f"An API error occurred while processing dataset {dataset_id}: {e.title} - {e.content}" + else: + message = ( + f"Failed running worker on dataset {dataset_id}: {repr(e)}" + ) + + logger.warning( + message, + exc_info=e if self.args.verbose else None, + ) + if dataset and self.generator: + # Try to update the state to Error regardless of the response + try: + self.update_dataset_state(dataset, DatasetState.Error) + except Exception: + pass + + if failed: + logger.error( + "Ran on {} dataset: {} completed, {} failed".format( + count, count - failed, failed + ) + ) + if failed >= count: # Everything failed! + sys.exit(1) + + +class DatasetExtractor(DatasetWorker): + def configure(self) -> None: + self.args: Namespace = self.parser.parse_args() + if self.is_read_only: + super().configure_for_developers() + else: + super().configure() + + if self.user_configuration: + logger.info("Overriding with user_configuration") + self.config.update(self.user_configuration) + + # Download corpus + self.download_latest_export() + + # Initialize db that will be written + self.configure_cache() + + # CachedImage downloaded and created in DB + self.cached_images = dict() + + # Where to save the downloaded images + self.image_folder = Path(tempfile.mkdtemp(suffix="-arkindex-data")) + logger.info(f"Images will be saved at `{self.image_folder}`.") + + def configure_cache(self) -> None: + """ + Create an SQLite database compatible with base-worker cache and initialize it. 
+ """ + self.use_cache = True + self.cache_path: Path = self.args.database or self.work_dir / "db.sqlite" + # Remove previous execution result if present + self.cache_path.unlink(missing_ok=True) + + init_cache_db(self.cache_path) + + create_version_table() + + create_tables() + + def download_latest_export(self) -> None: + """ + Download the latest export of the current corpus. + Export must be in `"done"` state. + """ + try: + exports = list( + self.api_client.paginate( + "ListExports", + id=self.corpus_id, + ) + ) + except ErrorResponse as e: + logger.error( + f"Could not list exports of corpus ({self.corpus_id}): {str(e.content)}" + ) + raise e + + # Find the latest that is in "done" state + exports: List[dict] = sorted( + list(filter(lambda exp: exp["state"] == "done", exports)), + key=itemgetter("updated"), + reverse=True, + ) + assert ( + len(exports) > 0 + ), f"No available exports found for the corpus {self.corpus_id}." + + # Download latest export + try: + export_id: str = exports[0]["id"] + logger.info(f"Downloading export ({export_id})...") + self.export: _TemporaryFileWrapper = self.api_client.request( + "DownloadExport", + id=export_id, + ) + logger.info(f"Downloaded export ({export_id}) @ `{self.export.name}`") + open_database(self.export.name) + except ErrorResponse as e: + logger.error( + f"Could not download export ({export_id}) of corpus ({self.corpus_id}): {str(e.content)}" + ) + raise e + + def insert_classifications(self, element: CachedElement) -> None: + logger.info("Listing classifications") + classifications: list[CachedClassification] = [ + CachedClassification( + id=classification.id, + element=element, + class_name=classification.class_name, + confidence=classification.confidence, + state=classification.state, + worker_run_id=classification.worker_run, + ) + for classification in list_classifications(element.id) + ] + if classifications: + logger.info(f"Inserting {len(classifications)} classification(s)") + with cache_database.atomic(): + CachedClassification.bulk_create( + model_list=classifications, + batch_size=BULK_BATCH_SIZE, + ) + + def insert_transcriptions( + self, element: CachedElement + ) -> List[CachedTranscription]: + logger.info("Listing transcriptions") + transcriptions: list[CachedTranscription] = [ + CachedTranscription( + id=transcription.id, + element=element, + text=transcription.text, + confidence=transcription.confidence, + orientation=DEFAULT_TRANSCRIPTION_ORIENTATION, + worker_version_id=transcription.worker_version, + worker_run_id=transcription.worker_run, + ) + for transcription in list_transcriptions(element.id) + ] + if transcriptions: + logger.info(f"Inserting {len(transcriptions)} transcription(s)") + with cache_database.atomic(): + CachedTranscription.bulk_create( + model_list=transcriptions, + batch_size=BULK_BATCH_SIZE, + ) + return transcriptions + + def insert_entities(self, transcriptions: List[CachedTranscription]) -> None: + logger.info("Listing entities") + entities: List[CachedEntity] = [] + transcription_entities: List[CachedTranscriptionEntity] = [] + for transcription in transcriptions: + for transcription_entity in list_transcription_entities(transcription.id): + entity = CachedEntity( + id=transcription_entity.entity.id, + type=transcription_entity.entity.type.name, + name=transcription_entity.entity.name, + validated=transcription_entity.entity.validated, + metas=transcription_entity.entity.metas, + worker_run_id=transcription_entity.entity.worker_run, + ) + entities.append(entity) + transcription_entities.append( 
+
+    def insert_entities(self, transcriptions: List[CachedTranscription]) -> None:
+        logger.info("Listing entities")
+        entities: List[CachedEntity] = []
+        transcription_entities: List[CachedTranscriptionEntity] = []
+        for transcription in transcriptions:
+            for transcription_entity in list_transcription_entities(transcription.id):
+                entity = CachedEntity(
+                    id=transcription_entity.entity.id,
+                    type=transcription_entity.entity.type.name,
+                    name=transcription_entity.entity.name,
+                    validated=transcription_entity.entity.validated,
+                    metas=transcription_entity.entity.metas,
+                    worker_run_id=transcription_entity.entity.worker_run,
+                )
+                entities.append(entity)
+                transcription_entities.append(
+                    CachedTranscriptionEntity(
+                        id=transcription_entity.id,
+                        transcription=transcription,
+                        entity=entity,
+                        offset=transcription_entity.offset,
+                        length=transcription_entity.length,
+                        confidence=transcription_entity.confidence,
+                        worker_run_id=transcription_entity.worker_run,
+                    )
+                )
+        if entities:
+            # First insert entities since they are foreign keys on transcription entities
+            logger.info(f"Inserting {len(entities)} entities")
+            with cache_database.atomic():
+                CachedEntity.bulk_create(
+                    model_list=entities,
+                    batch_size=BULK_BATCH_SIZE,
+                )
+
+        if transcription_entities:
+            # Insert transcription entities
+            logger.info(
+                f"Inserting {len(transcription_entities)} transcription entities"
+            )
+            with cache_database.atomic():
+                CachedTranscriptionEntity.bulk_create(
+                    model_list=transcription_entities,
+                    batch_size=BULK_BATCH_SIZE,
+                )
+
+    def insert_element(
+        self, element: Element, parent_id: Optional[UUID] = None
+    ) -> None:
+        """
+        Insert the given element in the cache database.
+        Its image will also be saved to disk, if it wasn't already.
+
+        The insertion of an element includes:
+        - its classifications
+        - its transcriptions
+        - its transcriptions' entities (both Entity and TranscriptionEntity)
+
+        :param element: Element to insert.
+        :param parent_id: ID of the parent to use when creating the CachedElement. Do not specify for top-level elements.
+        """
+        logger.info(f"Processing element ({element.id})")
+        if element.image and element.image.id not in self.cached_images:
+            # Download image
+            logger.info("Downloading image")
+            download_image(url=build_image_url(element)).save(
+                self.image_folder / f"{element.image.id}.jpg"
+            )
+            # Insert image
+            logger.info("Inserting image")
+            # Store images in case some other elements use it as well
+            with cache_database.atomic():
+                self.cached_images[element.image.id] = CachedImage.create(
+                    id=element.image.id,
+                    width=element.image.width,
+                    height=element.image.height,
+                    url=element.image.url,
+                )
+
+        # Insert element
+        logger.info("Inserting element")
+        with cache_database.atomic():
+            cached_element: CachedElement = CachedElement.create(
+                id=element.id,
+                parent_id=parent_id,
+                type=element.type,
+                image=self.cached_images[element.image.id] if element.image else None,
+                polygon=element.polygon,
+                rotation_angle=element.rotation_angle,
+                mirrored=element.mirrored,
+                worker_version_id=element.worker_version,
+                worker_run_id=element.worker_run,
+                confidence=element.confidence,
+            )
+
+        # Insert classifications
+        self.insert_classifications(cached_element)
+
+        # Insert transcriptions
+        transcriptions: List[CachedTranscription] = self.insert_transcriptions(
+            cached_element
+        )
+
+        # Insert entities
+        self.insert_entities(transcriptions)
+
+    def process_set(self, set_name: str, elements: List[Element]) -> None:
+        logger.info(
+            f"Filling the cache with information from elements in the set {set_name}"
+        )
+
+        # First list all pages
+        nb_elements: int = len(elements)
+        for idx, element in enumerate(elements, start=1):
+            logger.info(f"Processing `{set_name}` element ({idx}/{nb_elements})")
+
+            # Insert page
+            self.insert_element(element)
+
+            # List children
+            children = list_children(element.id)
+            nb_children: int = children.count()
+            for child_idx, child in enumerate(children, start=1):
+                logger.info(f"Processing child ({child_idx}/{nb_children})")
+                # Insert child
+                self.insert_element(child, parent_id=element.id)
+
+    def process_dataset(self, dataset: Dataset):
+        # Iterate over given sets
+        for set_name, elements in self.list_dataset_elements_per_set(dataset):
+            self.process_set(set_name, elements)
+
+        # Tar + zstd the image folder and store it as a task artifact
+        zstd_archive_path: Path = self.work_dir / f"{dataset.id}.zstd"
+        logger.info(f"Compressing the images to {zstd_archive_path}")
+        create_tar_zst_archive(source=self.image_folder, destination=zstd_archive_path)
+
+
+def main():
+    DatasetExtractor(
+        description="Fill base-worker cache with information about dataset and extract images",
+        generator=True,
+    ).run()
+
+
+if __name__ == "__main__":
+    main()