diff --git a/setup.py b/setup.py index 4f4d657fc81d35812fb2f4bff9741a22d97fcb44..81f4440affe641e4266ae08efa2af63c9b40c62c 100755 --- a/setup.py +++ b/setup.py @@ -47,11 +47,6 @@ setup( author="Teklia", author_email="contact@teklia.com", install_requires=parse_requirements(), - entry_points={ - "console_scripts": [ - f"{COMMAND}={MODULE}.worker:main", - "worker-generic-training-dataset-new=worker_generic_training_dataset.dataset_worker:main", - ] - }, + entry_points={"console_scripts": [f"{COMMAND}={MODULE}.worker:main"]}, packages=find_packages(), ) diff --git a/worker_generic_training_dataset/dataset_worker.py b/worker_generic_training_dataset/dataset_worker.py deleted file mode 100644 index a82c4649360a45447f5f8826a3c187cf4c6369e1..0000000000000000000000000000000000000000 --- a/worker_generic_training_dataset/dataset_worker.py +++ /dev/null @@ -1,458 +0,0 @@ -# -*- coding: utf-8 -*- -import logging -import sys -import tempfile -from argparse import Namespace -from itertools import groupby -from operator import itemgetter -from pathlib import Path -from tempfile import _TemporaryFileWrapper -from typing import Iterator, List, Optional, Tuple -from uuid import UUID - -from apistar.exceptions import ErrorResponse -from arkindex_export import Element, open_database -from arkindex_export.queries import list_children -from arkindex_worker.cache import ( - CachedClassification, - CachedElement, - CachedEntity, - CachedImage, - CachedTranscription, - CachedTranscriptionEntity, - create_tables, - create_version_table, -) -from arkindex_worker.cache import db as cache_database -from arkindex_worker.cache import init_cache_db -from arkindex_worker.image import download_image -from arkindex_worker.models import Dataset -from arkindex_worker.utils import create_tar_zst_archive -from arkindex_worker.worker.base import BaseWorker -from arkindex_worker.worker.dataset import DatasetMixin, DatasetState -from worker_generic_training_dataset.db import ( - list_classifications, - list_transcription_entities, - list_transcriptions, -) -from worker_generic_training_dataset.utils import build_image_url -from worker_generic_training_dataset.worker import ( - BULK_BATCH_SIZE, - DEFAULT_TRANSCRIPTION_ORIENTATION, -) - -logger: logging.Logger = logging.getLogger(__name__) - - -class DatasetWorker(BaseWorker, DatasetMixin): - def __init__( - self, - description: str = "Arkindex Elements Worker", - support_cache: bool = False, - generator: bool = False, - ): - super().__init__(description, support_cache) - - self.parser.add_argument( - "--dataset", - type=UUID, - nargs="+", - help="One or more Arkindex dataset ID", - ) - - self.generator = generator - - def list_dataset_elements_per_set( - self, dataset: Dataset - ) -> Iterator[Tuple[str, Element]]: - """ - Calls `list_dataset_elements` but returns results grouped by Set - """ - - def format_element(element): - return Element.get(Element.id == element[1].id) - - def format_set(set): - return (set[0], list(map(format_element, list(set[1])))) - - return list( - map( - format_set, - groupby( - sorted(self.list_dataset_elements(dataset), key=itemgetter(0)), - key=itemgetter(0), - ), - ) - ) - - def process_dataset(self, dataset: Dataset): - """ - Override this method to implement your worker and process a single Arkindex dataset at once. - - :param dataset: The dataset to process. 
- """ - - def list_datasets(self) -> List[Dataset] | List[str]: - """ - Calls `list_process_datasets` if not is_read_only, - else simply give the list of IDs provided via CLI - """ - if self.is_read_only: - return list(map(str, self.args.dataset)) - - return self.list_process_datasets() - - def run(self): - self.configure() - - datasets: List[Dataset] | List[str] = self.list_datasets() - if not datasets: - logger.warning("No datasets to process, stopping.") - sys.exit(1) - - # Process every dataset - count = len(datasets) - failed = 0 - for i, item in enumerate(datasets, start=1): - dataset = None - try: - if not self.is_read_only: - # Just use the result of list_datasets as the dataset - dataset = item - else: - # Load dataset using the Arkindex API - dataset = Dataset(**self.request("RetrieveDataset", id=item)) - - if self.generator: - assert ( - dataset.state == DatasetState.Open.value - ), "When generating a new dataset, its state should be Open" - else: - assert ( - dataset.state == DatasetState.Complete.value - ), "When processing an existing dataset, its state should be Complete" - - if self.generator: - # Update the dataset state to Building - logger.info(f"Building {dataset} ({i}/{count})") - self.update_dataset_state(dataset, DatasetState.Building) - - # Process the dataset - self.process_dataset(dataset) - - if self.generator: - # Update the dataset state to Complete - logger.info(f"Completed {dataset} ({i}/{count})") - self.update_dataset_state(dataset, DatasetState.Complete) - except Exception as e: - # Handle errors occurring while retrieving, processing or patching the state for this dataset. - failed += 1 - - # Handle the case where we failed retrieving the dataset - dataset_id = dataset.id if dataset else item - - if isinstance(e, ErrorResponse): - message = f"An API error occurred while processing dataset {dataset_id}: {e.title} - {e.content}" - else: - message = ( - f"Failed running worker on dataset {dataset_id}: {repr(e)}" - ) - - logger.warning( - message, - exc_info=e if self.args.verbose else None, - ) - if dataset and self.generator: - # Try to update the state to Error regardless of the response - try: - self.update_dataset_state(dataset, DatasetState.Error) - except Exception: - pass - - if failed: - logger.error( - "Ran on {} dataset: {} completed, {} failed".format( - count, count - failed, failed - ) - ) - if failed >= count: # Everything failed! - sys.exit(1) - - -class DatasetExtractor(DatasetWorker): - def configure(self) -> None: - self.args: Namespace = self.parser.parse_args() - if self.is_read_only: - super().configure_for_developers() - else: - super().configure() - - if self.user_configuration: - logger.info("Overriding with user_configuration") - self.config.update(self.user_configuration) - - # Download corpus - self.download_latest_export() - - # Initialize db that will be written - self.configure_cache() - - # CachedImage downloaded and created in DB - self.cached_images = dict() - - # Where to save the downloaded images - self.image_folder = Path(tempfile.mkdtemp(suffix="-arkindex-data")) - logger.info(f"Images will be saved at `{self.image_folder}`.") - - def configure_cache(self) -> None: - """ - Create an SQLite database compatible with base-worker cache and initialize it. 
- """ - self.use_cache = True - self.cache_path: Path = self.args.database or self.work_dir / "db.sqlite" - # Remove previous execution result if present - self.cache_path.unlink(missing_ok=True) - - init_cache_db(self.cache_path) - - create_version_table() - - create_tables() - - def download_latest_export(self) -> None: - """ - Download the latest export of the current corpus. - Export must be in `"done"` state. - """ - try: - exports = list( - self.api_client.paginate( - "ListExports", - id=self.corpus_id, - ) - ) - except ErrorResponse as e: - logger.error( - f"Could not list exports of corpus ({self.corpus_id}): {str(e.content)}" - ) - raise e - - # Find the latest that is in "done" state - exports: List[dict] = sorted( - list(filter(lambda exp: exp["state"] == "done", exports)), - key=itemgetter("updated"), - reverse=True, - ) - assert ( - len(exports) > 0 - ), f"No available exports found for the corpus {self.corpus_id}." - - # Download latest export - try: - export_id: str = exports[0]["id"] - logger.info(f"Downloading export ({export_id})...") - self.export: _TemporaryFileWrapper = self.api_client.request( - "DownloadExport", - id=export_id, - ) - logger.info(f"Downloaded export ({export_id}) @ `{self.export.name}`") - open_database(self.export.name) - except ErrorResponse as e: - logger.error( - f"Could not download export ({export_id}) of corpus ({self.corpus_id}): {str(e.content)}" - ) - raise e - - def insert_classifications(self, element: CachedElement) -> None: - logger.info("Listing classifications") - classifications: list[CachedClassification] = [ - CachedClassification( - id=classification.id, - element=element, - class_name=classification.class_name, - confidence=classification.confidence, - state=classification.state, - worker_run_id=classification.worker_run, - ) - for classification in list_classifications(element.id) - ] - if classifications: - logger.info(f"Inserting {len(classifications)} classification(s)") - with cache_database.atomic(): - CachedClassification.bulk_create( - model_list=classifications, - batch_size=BULK_BATCH_SIZE, - ) - - def insert_transcriptions( - self, element: CachedElement - ) -> List[CachedTranscription]: - logger.info("Listing transcriptions") - transcriptions: list[CachedTranscription] = [ - CachedTranscription( - id=transcription.id, - element=element, - text=transcription.text, - confidence=transcription.confidence, - orientation=DEFAULT_TRANSCRIPTION_ORIENTATION, - worker_version_id=transcription.worker_version, - worker_run_id=transcription.worker_run, - ) - for transcription in list_transcriptions(element.id) - ] - if transcriptions: - logger.info(f"Inserting {len(transcriptions)} transcription(s)") - with cache_database.atomic(): - CachedTranscription.bulk_create( - model_list=transcriptions, - batch_size=BULK_BATCH_SIZE, - ) - return transcriptions - - def insert_entities(self, transcriptions: List[CachedTranscription]) -> None: - logger.info("Listing entities") - entities: List[CachedEntity] = [] - transcription_entities: List[CachedTranscriptionEntity] = [] - for transcription in transcriptions: - for transcription_entity in list_transcription_entities(transcription.id): - entity = CachedEntity( - id=transcription_entity.entity.id, - type=transcription_entity.entity.type.name, - name=transcription_entity.entity.name, - validated=transcription_entity.entity.validated, - metas=transcription_entity.entity.metas, - worker_run_id=transcription_entity.entity.worker_run, - ) - entities.append(entity) - transcription_entities.append( 
- CachedTranscriptionEntity( - id=transcription_entity.id, - transcription=transcription, - entity=entity, - offset=transcription_entity.offset, - length=transcription_entity.length, - confidence=transcription_entity.confidence, - worker_run_id=transcription_entity.worker_run, - ) - ) - if entities: - # First insert entities since they are foreign keys on transcription entities - logger.info(f"Inserting {len(entities)} entities") - with cache_database.atomic(): - CachedEntity.bulk_create( - model_list=entities, - batch_size=BULK_BATCH_SIZE, - ) - - if transcription_entities: - # Insert transcription entities - logger.info( - f"Inserting {len(transcription_entities)} transcription entities" - ) - with cache_database.atomic(): - CachedTranscriptionEntity.bulk_create( - model_list=transcription_entities, - batch_size=BULK_BATCH_SIZE, - ) - - def insert_element( - self, element: Element, parent_id: Optional[UUID] = None - ) -> None: - """ - Insert the given element in the cache database. - Its image will also be saved to disk, if it wasn't already. - - The insertion of an element includes: - - its classifications - - its transcriptions - - its transcriptions' entities (both Entity and TranscriptionEntity) - - :param element: Element to insert. - :param parent_id: ID of the parent to use when creating the CachedElement. Do not specify for top-level elements. - """ - logger.info(f"Processing element ({element.id})") - if element.image and element.image.id not in self.cached_images: - # Download image - logger.info("Downloading image") - download_image(url=build_image_url(element)).save( - self.image_folder / f"{element.image.id}.jpg" - ) - # Insert image - logger.info("Inserting image") - # Store images in case some other elements use it as well - with cache_database.atomic(): - self.cached_images[element.image.id] = CachedImage.create( - id=element.image.id, - width=element.image.width, - height=element.image.height, - url=element.image.url, - ) - - # Insert element - logger.info("Inserting element") - with cache_database.atomic(): - cached_element: CachedElement = CachedElement.create( - id=element.id, - parent_id=parent_id, - type=element.type, - image=self.cached_images[element.image.id] if element.image else None, - polygon=element.polygon, - rotation_angle=element.rotation_angle, - mirrored=element.mirrored, - worker_version_id=element.worker_version, - worker_run_id=element.worker_run, - confidence=element.confidence, - ) - - # Insert classifications - self.insert_classifications(cached_element) - - # Insert transcriptions - transcriptions: List[CachedTranscription] = self.insert_transcriptions( - cached_element - ) - - # Insert entities - self.insert_entities(transcriptions) - - def process_set(self, set_name: str, elements: List[Element]) -> None: - logger.info( - f"Filling the cache with information from elements in the set {set_name}" - ) - - # First list all pages - nb_elements: int = len(elements) - for idx, element in enumerate(elements, start=1): - logger.info(f"Processing `{set_name}` element ({idx}/{nb_elements})") - - # Insert page - self.insert_element(element) - - # List children - children = list_children(element.id) - nb_children: int = children.count() - for child_idx, child in enumerate(children, start=1): - logger.info(f"Processing child ({child_idx}/{nb_children})") - # Insert child - self.insert_element(child, parent_id=element.id) - - def process_dataset(self, dataset: Dataset): - # Iterate over given sets - for set_name, elements in 
self.list_dataset_elements_per_set(dataset): - self.process_set(set_name, elements) - - # TAR + ZSTD Image folder and store as task artifact - zstd_archive_path: Path = self.work_dir / f"{dataset.id}.zstd" - logger.info(f"Compressing the images to {zstd_archive_path}") - create_tar_zst_archive(source=self.image_folder, destination=zstd_archive_path) - - -def main(): - DatasetExtractor( - description="Fill base-worker cache with information about dataset and extract images", - generator=True, - ).run() - - -if __name__ == "__main__": - main() diff --git a/worker_generic_training_dataset/worker.py b/worker_generic_training_dataset/worker.py index b20a04dae21d43471ced457e49172555b9d8013e..532104c4005745c0798b4b9d0de178f099f489d2 100644 --- a/worker_generic_training_dataset/worker.py +++ b/worker_generic_training_dataset/worker.py @@ -1,15 +1,17 @@ # -*- coding: utf-8 -*- import logging -import operator +import sys import tempfile from argparse import Namespace +from itertools import groupby +from operator import itemgetter from pathlib import Path from tempfile import _TemporaryFileWrapper -from typing import List, Optional +from typing import Iterator, List, Optional, Tuple from uuid import UUID from apistar.exceptions import ErrorResponse -from arkindex_export import Element, Image, open_database +from arkindex_export import Element, open_database from arkindex_export.queries import list_children from arkindex_worker.cache import ( CachedClassification, @@ -24,13 +26,14 @@ from arkindex_worker.cache import ( from arkindex_worker.cache import db as cache_database from arkindex_worker.cache import init_cache_db from arkindex_worker.image import download_image +from arkindex_worker.models import Dataset from arkindex_worker.utils import create_tar_zst_archive from arkindex_worker.worker.base import BaseWorker +from arkindex_worker.worker.dataset import DatasetMixin, DatasetState from worker_generic_training_dataset.db import ( list_classifications, list_transcription_entities, list_transcriptions, - retrieve_element, ) from worker_generic_training_dataset.utils import build_image_url @@ -40,7 +43,142 @@ BULK_BATCH_SIZE = 50 DEFAULT_TRANSCRIPTION_ORIENTATION = "horizontal-lr" -class DatasetExtractor(BaseWorker): +class DatasetWorker(BaseWorker, DatasetMixin): + def __init__( + self, + description: str = "Arkindex Elements Worker", + support_cache: bool = False, + generator: bool = False, + ): + super().__init__(description, support_cache) + + self.parser.add_argument( + "--dataset", + type=UUID, + nargs="+", + help="One or more Arkindex dataset ID", + ) + + self.generator = generator + + def list_dataset_elements_per_set( + self, dataset: Dataset + ) -> Iterator[Tuple[str, Element]]: + """ + Calls `list_dataset_elements` but returns results grouped by Set + """ + + def format_element(element): + return Element.get(Element.id == element[1].id) + + def format_set(set): + return (set[0], list(map(format_element, list(set[1])))) + + return list( + map( + format_set, + groupby( + sorted(self.list_dataset_elements(dataset), key=itemgetter(0)), + key=itemgetter(0), + ), + ) + ) + + def process_dataset(self, dataset: Dataset): + """ + Override this method to implement your worker and process a single Arkindex dataset at once. + + :param dataset: The dataset to process. 
+ """ + + def list_datasets(self) -> List[Dataset] | List[str]: + """ + Calls `list_process_datasets` if not is_read_only, + else simply give the list of IDs provided via CLI + """ + if self.is_read_only: + return list(map(str, self.args.dataset)) + + return self.list_process_datasets() + + def run(self): + self.configure() + + datasets: List[Dataset] | List[str] = self.list_datasets() + if not datasets: + logger.warning("No datasets to process, stopping.") + sys.exit(1) + + # Process every dataset + count = len(datasets) + failed = 0 + for i, item in enumerate(datasets, start=1): + dataset = None + try: + if not self.is_read_only: + # Just use the result of list_datasets as the dataset + dataset = item + else: + # Load dataset using the Arkindex API + dataset = Dataset(**self.request("RetrieveDataset", id=item)) + + if self.generator: + assert ( + dataset.state == DatasetState.Open.value + ), "When generating a new dataset, its state should be Open" + else: + assert ( + dataset.state == DatasetState.Complete.value + ), "When processing an existing dataset, its state should be Complete" + + if self.generator: + # Update the dataset state to Building + logger.info(f"Building {dataset} ({i}/{count})") + self.update_dataset_state(dataset, DatasetState.Building) + + # Process the dataset + self.process_dataset(dataset) + + if self.generator: + # Update the dataset state to Complete + logger.info(f"Completed {dataset} ({i}/{count})") + self.update_dataset_state(dataset, DatasetState.Complete) + except Exception as e: + # Handle errors occurring while retrieving, processing or patching the state for this dataset. + failed += 1 + + # Handle the case where we failed retrieving the dataset + dataset_id = dataset.id if dataset else item + + if isinstance(e, ErrorResponse): + message = f"An API error occurred while processing dataset {dataset_id}: {e.title} - {e.content}" + else: + message = ( + f"Failed running worker on dataset {dataset_id}: {repr(e)}" + ) + + logger.warning( + message, + exc_info=e if self.args.verbose else None, + ) + if dataset and self.generator: + # Try to update the state to Error regardless of the response + try: + self.update_dataset_state(dataset, DatasetState.Error) + except Exception: + pass + + if failed: + logger.error( + "Ran on {} dataset: {} completed, {} failed".format( + count, count - failed, failed + ) + ) + if failed >= count: # Everything failed! 
+ sys.exit(1) + + +class DatasetExtractor(DatasetWorker): def configure(self) -> None: self.args: Namespace = self.parser.parse_args() if self.is_read_only: @@ -52,9 +190,6 @@ class DatasetExtractor(BaseWorker): logger.info("Overriding with user_configuration") self.config.update(self.user_configuration) - # Read process information - self.read_training_related_information() - # Download corpus self.download_latest_export() @@ -68,28 +203,6 @@ class DatasetExtractor(BaseWorker): self.image_folder = Path(tempfile.mkdtemp(suffix="-arkindex-data")) logger.info(f"Images will be saved at `{self.image_folder}`.") - def read_training_related_information(self) -> None: - """ - Read from process information - - train_folder_id - - validation_folder_id - - test_folder_id (optional) - """ - logger.info("Retrieving information from process_information") - - train_folder_id = self.process_information.get("train_folder_id") - assert train_folder_id, "A training folder id is necessary to use this worker" - self.training_folder_id = UUID(train_folder_id) - - val_folder_id = self.process_information.get("validation_folder_id") - assert val_folder_id, "A validation folder id is necessary to use this worker" - self.validation_folder_id = UUID(val_folder_id) - - test_folder_id = self.process_information.get("test_folder_id") - self.testing_folder_id: UUID | None = ( - UUID(test_folder_id) if test_folder_id else None - ) - def configure_cache(self) -> None: """ Create an SQLite database compatible with base-worker cache and initialize it. @@ -126,7 +239,7 @@ class DatasetExtractor(BaseWorker): # Find the latest that is in "done" state exports: List[dict] = sorted( list(filter(lambda exp: exp["state"] == "done", exports)), - key=operator.itemgetter("updated"), + key=itemgetter("updated"), reverse=True, ) assert ( @@ -301,52 +414,34 @@ class DatasetExtractor(BaseWorker): # Insert entities self.insert_entities(transcriptions) - def process_split(self, split_name: str, split_id: UUID) -> None: - """ - Insert all elements under the given parent folder (all queries are recursive). 
- - `page` elements are linked to this folder (via parent_id foreign key) - - `page` element children are linked to their `page` parent (via parent_id foreign key) - """ + def process_set(self, set_name: str, elements: List[Element]) -> None: logger.info( - f"Filling the Base-Worker cache with information from children under element ({split_id})" + f"Filling the cache with information from elements in the set {set_name}" ) - # Fill cache - # Retrieve parent and create parent - parent: Element = retrieve_element(split_id) - self.insert_element(parent) # First list all pages - pages = list_children(split_id).join(Image).where(Element.type == "page") - nb_pages: int = pages.count() - for idx, page in enumerate(pages, start=1): - logger.info(f"Processing `{split_name}` page ({idx}/{nb_pages})") + nb_elements: int = len(elements) + for idx, element in enumerate(elements, start=1): + logger.info(f"Processing `{set_name}` element ({idx}/{nb_elements})") # Insert page - self.insert_element(page, parent_id=split_id) + self.insert_element(element) # List children - children = list_children(page.id) + children = list_children(element.id) nb_children: int = children.count() for child_idx, child in enumerate(children, start=1): logger.info(f"Processing child ({child_idx}/{nb_children})") # Insert child - self.insert_element(child, parent_id=page.id) - - def run(self): - self.configure() + self.insert_element(child, parent_id=element.id) - # Iterate over given split - for split_name, split_id in [ - ("Train", self.training_folder_id), - ("Validation", self.validation_folder_id), - ("Test", self.testing_folder_id), - ]: - if not split_id: - continue - self.process_split(split_name, split_id) + def process_dataset(self, dataset: Dataset): + # Iterate over given sets + for set_name, elements in self.list_dataset_elements_per_set(dataset): + self.process_set(set_name, elements) # TAR + ZSTD Image folder and store as task artifact - zstd_archive_path: Path = self.work_dir / "arkindex_data.zstd" + zstd_archive_path: Path = self.work_dir / f"{dataset.id}.zstd" logger.info(f"Compressing the images to {zstd_archive_path}") create_tar_zst_archive(source=self.image_folder, destination=zstd_archive_path) @@ -354,7 +449,7 @@ class DatasetExtractor(BaseWorker): def main(): DatasetExtractor( description="Fill base-worker cache with information about dataset and extract images", - support_cache=True, + generator=True, ).run()
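The net effect of this patch: the separate `dataset_worker.py` module and the old split-folder logic (`train_folder_id` / `validation_folder_id` / `test_folder_id` read from `process_information`) are removed, and `worker.py` gains a reusable `DatasetWorker` base class that lists datasets, groups their elements per set, and drives the `Open` → `Building` → `Complete` (or `Error`) state transitions when `generator=True`. Below is a minimal sketch of how a subclass is meant to use that base class; `MyDatasetWorker` and the logging inside its `process_dataset()` are invented for illustration, while the class, methods, and keyword arguments it relies on come from the code added in this diff.

```python
# Hypothetical subclass of the DatasetWorker base class added in worker.py.
# Only `MyDatasetWorker` and its process_dataset() body are illustrative.
import logging

from arkindex_worker.models import Dataset
from worker_generic_training_dataset.worker import DatasetWorker

logger = logging.getLogger(__name__)


class MyDatasetWorker(DatasetWorker):
    def process_dataset(self, dataset: Dataset):
        # Elements arrive already grouped by dataset set thanks to
        # list_dataset_elements_per_set()
        for set_name, elements in self.list_dataset_elements_per_set(dataset):
            logger.info(f"Set `{set_name}`: {len(elements)} element(s)")


def main():
    # With generator=True, run() expects each dataset to be in the Open state,
    # moves it to Building before process_dataset() and to Complete afterwards,
    # falling back to Error if processing raises.
    MyDatasetWorker(description="Example dataset worker", generator=True).run()


if __name__ == "__main__":
    main()
```

Datasets are taken either from `--dataset <uuid> [<uuid> ...]` on the command line when the worker runs in read-only/developer mode, or from `list_process_datasets()` when it runs inside an Arkindex process.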
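For reference, `list_dataset_elements_per_set()` is a plain sort-then-group over the `(set_name, element)` tuples yielded by `list_dataset_elements()`. The standalone sketch below shows only that grouping step, with string placeholders standing in for the real Arkindex elements (the actual method additionally resolves each element through `Element.get()`).

```python
# Simplified illustration of the grouping strategy used by
# list_dataset_elements_per_set(); the string values are stand-ins.
from itertools import groupby
from operator import itemgetter

pairs = [
    ("train", "element-1"),
    ("dev", "element-2"),
    ("train", "element-3"),
]

grouped = [
    (set_name, [element for _, element in items])
    for set_name, items in groupby(
        sorted(pairs, key=itemgetter(0)), key=itemgetter(0)
    )
]
print(grouped)  # [('dev', ['element-2']), ('train', ['element-1', 'element-3'])]
```

Sorting by set name before calling `itertools.groupby` is what produces one group per set rather than one group per contiguous run of identical set names.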