From 2ce6237adc7336ca1f5046a77a6cf2068dadf00e Mon Sep 17 00:00:00 2001
From: EvaBardou <bardou@teklia.com>
Date: Thu, 19 Oct 2023 16:58:27 +0200
Subject: [PATCH] Use DatasetWorker from arkindex-base-worker

---
 worker_generic_training_dataset/worker.py | 145 +---------------------
 1 file changed, 6 insertions(+), 139 deletions(-)

diff --git a/worker_generic_training_dataset/worker.py b/worker_generic_training_dataset/worker.py
index c189f77..93a4c75 100644
--- a/worker_generic_training_dataset/worker.py
+++ b/worker_generic_training_dataset/worker.py
@@ -1,13 +1,11 @@
 # -*- coding: utf-8 -*-
 import logging
-import sys
 import tempfile
 from argparse import Namespace
-from itertools import groupby
 from operator import itemgetter
 from pathlib import Path
 from tempfile import _TemporaryFileWrapper
-from typing import Iterator, List, Optional, Tuple
+from typing import List, Optional
 from uuid import UUID
 
 from apistar.exceptions import ErrorResponse
@@ -29,8 +27,7 @@ from arkindex_worker.image import download_image
 from arkindex_worker.models import Dataset
 from arkindex_worker.models import Element as WorkerElement
 from arkindex_worker.utils import create_tar_zst_archive
-from arkindex_worker.worker.base import BaseWorker
-from arkindex_worker.worker.dataset import DatasetMixin, DatasetState
+from arkindex_worker.worker import DatasetWorker
 from worker_generic_training_dataset.db import (
     list_classifications,
     list_transcription_entities,
@@ -45,139 +42,8 @@ BULK_BATCH_SIZE = 50
 DEFAULT_TRANSCRIPTION_ORIENTATION = "horizontal-lr"
 
 
-class DatasetWorker(BaseWorker, DatasetMixin):
-    def __init__(
-        self,
-        description: str = "Arkindex Elements Worker",
-        support_cache: bool = False,
-        generator: bool = False,
-    ):
-        super().__init__(description, support_cache)
-
-        self.parser.add_argument(
-            "--dataset",
-            type=UUID,
-            nargs="+",
-            help="One or more Arkindex dataset ID",
-        )
-
-        self.generator = generator
-
-    def list_dataset_elements_per_split(
-        self, dataset: Dataset
-    ) -> Iterator[Tuple[str, List[Element]]]:
-        """
-        Calls `list_dataset_elements` but returns results grouped by Set
-        """
-
-        def format_element(element: Tuple[str, WorkerElement]) -> Element:
-            return retrieve_element(element[1].id)
-
-        def format_split(
-            split: Tuple[str, Iterator[Tuple[str, WorkerElement]]]
-        ) -> Tuple[str, List[Element]]:
-            return (split[0], list(map(format_element, list(split[1]))))
-
-        return map(
-            format_split,
-            groupby(
-                sorted(self.list_dataset_elements(dataset), key=itemgetter(0)),
-                key=itemgetter(0),
-            ),
-        )
-
-    def process_dataset(self, dataset: Dataset):
-        """
-        Override this method to implement your worker and process a single Arkindex dataset at once.
-
-        :param dataset: The dataset to process.
-        """
-
-    def list_datasets(self) -> Iterator[Dataset] | Iterator[str]:
-        """
-        Calls `list_process_datasets` if not is_read_only,
-        else simply give the list of IDs provided via CLI
-        """
-        if self.is_read_only:
-            return map(str, self.args.dataset)
-
-        return self.list_process_datasets()
-
-    def run(self):
-        self.configure()
-
-        datasets: Iterator[Dataset] | Iterator[str] = list(self.list_datasets())
-        if not datasets:
-            logger.warning("No datasets to process, stopping.")
-            sys.exit(1)
-
-        # Process every dataset
-        count = len(datasets)
-        failed = 0
-        for i, item in enumerate(datasets, start=1):
-            dataset = None
-            try:
-                if not self.is_read_only:
-                    # Just use the result of list_datasets as the dataset
-                    dataset = item
-                else:
-                    # Load dataset using the Arkindex API
-                    dataset = Dataset(**self.request("RetrieveDataset", id=item))
-
-                if self.generator:
-                    assert (
-                        dataset.state == DatasetState.Open.value
-                    ), "When generating a new dataset, its state should be Open."
-                else:
-                    assert (
-                        dataset.state == DatasetState.Complete.value
-                    ), "When processing an existing dataset, its state should be Complete."
-
-                if self.generator:
-                    # Update the dataset state to Building
-                    logger.info(f"Building {dataset} ({i}/{count})")
-                    self.update_dataset_state(dataset, DatasetState.Building)
-
-                # Process the dataset
-                self.process_dataset(dataset)
-
-                if self.generator:
-                    # Update the dataset state to Complete
-                    logger.info(f"Completed {dataset} ({i}/{count})")
-                    self.update_dataset_state(dataset, DatasetState.Complete)
-            except Exception as e:
-                # Handle errors occurring while retrieving, processing or patching the state for this dataset.
-                failed += 1
-
-                # Handle the case where we failed retrieving the dataset
-                dataset_id = dataset.id if dataset else item
-
-                if isinstance(e, ErrorResponse):
-                    message = f"An API error occurred while processing dataset {dataset_id}: {e.title} - {e.content}"
-                else:
-                    message = (
-                        f"Failed running worker on dataset {dataset_id}: {repr(e)}"
-                    )
-
-                logger.warning(
-                    message,
-                    exc_info=e if self.args.verbose else None,
-                )
-                if dataset and self.generator:
-                    # Try to update the state to Error regardless of the response
-                    try:
-                        self.update_dataset_state(dataset, DatasetState.Error)
-                    except Exception:
-                        pass
-
-        if failed:
-            logger.error(
-                "Ran on {} dataset: {} completed, {} failed".format(
-                    count, count - failed, failed
-                )
-            )
-            if failed >= count:  # Everything failed!
-                sys.exit(1)
+def _format_element(element: WorkerElement) -> Element:
+    return retrieve_element(element.id)
 
 
 class DatasetExtractor(DatasetWorker):
@@ -440,7 +306,8 @@ class DatasetExtractor(DatasetWorker):
     def process_dataset(self, dataset: Dataset):
         # Iterate over given splits
         for split_name, elements in self.list_dataset_elements_per_split(dataset):
-            self.process_split(split_name, elements)
+            casted_elements = list(map(_format_element, elements))
+            self.process_split(split_name, casted_elements)
 
         # TAR + ZSTD Image folder and store as task artifact
         zstd_archive_path: Path = self.work_dir / f"{dataset.id}.zstd"
--
GitLab