
New DatasetExtractor using a DatasetWorker

Merged: Eva Bardou requested to merge dataset-worker into main
All threads resolved!
1 file changed: +6 -139
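This merge request drops the local DatasetWorker base class defined in this file and makes DatasetExtractor subclass the DatasetWorker shipped with arkindex_worker, so that only the dataset-specific extraction logic remains in this repository. As a rough sketch of the resulting structure, assuming the upstream DatasetWorker provides the same run() and list_dataset_elements_per_split() behaviour as the local class it replaces (the description string and main() entry point below are illustrative, not part of this diff):

from arkindex_worker.models import Dataset
from arkindex_worker.worker import DatasetWorker


class DatasetExtractor(DatasetWorker):
    def process_dataset(self, dataset: Dataset):
        # run() is inherited from the upstream worker: it lists the datasets to
        # process, checks their state and calls this method once per dataset.
        for split_name, elements in self.list_dataset_elements_per_split(dataset):
            ...  # extract the split, then archive it as a task artifact


def main():
    DatasetExtractor(description="Generic training dataset extractor").run()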
# -*- coding: utf-8 -*-
import logging
import sys
import tempfile
from argparse import Namespace
from itertools import groupby
from operator import itemgetter
from pathlib import Path
from tempfile import _TemporaryFileWrapper
from typing import Iterator, List, Optional, Tuple
from typing import List, Optional
from uuid import UUID
from apistar.exceptions import ErrorResponse
@@ -29,8 +27,7 @@ from arkindex_worker.image import download_image
from arkindex_worker.models import Dataset
from arkindex_worker.models import Element as WorkerElement
from arkindex_worker.utils import create_tar_zst_archive
from arkindex_worker.worker.base import BaseWorker
from arkindex_worker.worker.dataset import DatasetMixin, DatasetState
from arkindex_worker.worker import DatasetWorker
from worker_generic_training_dataset.db import (
list_classifications,
list_transcription_entities,
@@ -45,139 +42,8 @@ BULK_BATCH_SIZE = 50
DEFAULT_TRANSCRIPTION_ORIENTATION = "horizontal-lr"
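# Generic base class: lists the datasets to process, loads them through the API,
# checks and updates their state, and delegates the actual work to process_dataset().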
class DatasetWorker(BaseWorker, DatasetMixin):
def __init__(
self,
description: str = "Arkindex Elements Worker",
support_cache: bool = False,
generator: bool = False,
):
super().__init__(description, support_cache)
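# Dataset IDs passed on the command line are used by list_datasets() when the
# worker runs in read-only mode, instead of the datasets attached to the process.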
self.parser.add_argument(
"--dataset",
type=UUID,
nargs="+",
help="One or more Arkindex dataset ID",
)
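# generator=True means the worker creates the dataset contents itself (starting
# from an Open dataset) instead of consuming an already Complete one.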
self.generator = generator
def list_dataset_elements_per_split(
self, dataset: Dataset
) -> Iterator[Tuple[str, List[Element]]]:
"""
Calls `list_dataset_elements` but returns results grouped by Set
"""
def format_element(element: Tuple[str, WorkerElement]) -> Element:
return retrieve_element(element[1].id)
def format_split(
split: Tuple[str, Iterator[Tuple[str, WorkerElement]]]
) -> Tuple[str, List[Element]]:
return (split[0], list(map(format_element, list(split[1]))))
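# list_dataset_elements() yields (set_name, element) pairs: sort and group them
# by set name, retrieving each element in full along the way.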
return map(
format_split,
groupby(
sorted(self.list_dataset_elements(dataset), key=itemgetter(0)),
key=itemgetter(0),
),
)
def process_dataset(self, dataset: Dataset):
"""
Override this method to implement your worker and process a single Arkindex dataset at once.
:param dataset: The dataset to process.
"""
def list_datasets(self) -> Iterator[Dataset] | Iterator[str]:
"""
Calls `list_process_datasets` if the worker is not read-only,
otherwise simply returns the list of dataset IDs provided via the CLI
"""
if self.is_read_only:
return map(str, self.args.dataset)
return self.list_process_datasets()
def run(self):
self.configure()
datasets: Iterator[Dataset] | Iterator[str] = list(self.list_datasets())
if not datasets:
logger.warning("No datasets to process, stopping.")
sys.exit(1)
# Process every dataset
count = len(datasets)
failed = 0
for i, item in enumerate(datasets, start=1):
dataset = None
try:
if not self.is_read_only:
# Just use the result of list_datasets as the dataset
dataset = item
else:
# Load dataset using the Arkindex API
dataset = Dataset(**self.request("RetrieveDataset", id=item))
if self.generator:
assert (
dataset.state == DatasetState.Open.value
), "When generating a new dataset, its state should be Open."
else:
assert (
dataset.state == DatasetState.Complete.value
), "When processing an existing dataset, its state should be Complete."
if self.generator:
# Update the dataset state to Building
logger.info(f"Building {dataset} ({i}/{count})")
self.update_dataset_state(dataset, DatasetState.Building)
# Process the dataset
self.process_dataset(dataset)
if self.generator:
# Update the dataset state to Complete
logger.info(f"Completed {dataset} ({i}/{count})")
self.update_dataset_state(dataset, DatasetState.Complete)
except Exception as e:
# Handle errors occurring while retrieving, processing or patching the state for this dataset.
failed += 1
# Handle the case where we failed retrieving the dataset
dataset_id = dataset.id if dataset else item
if isinstance(e, ErrorResponse):
message = f"An API error occurred while processing dataset {dataset_id}: {e.title} - {e.content}"
else:
message = (
f"Failed running worker on dataset {dataset_id}: {repr(e)}"
)
logger.warning(
message,
exc_info=e if self.args.verbose else None,
)
if dataset and self.generator:
# Try to update the state to Error regardless of the response
try:
self.update_dataset_state(dataset, DatasetState.Error)
except Exception:
pass
if failed:
logger.error(
"Ran on {} dataset: {} completed, {} failed".format(
count, count - failed, failed
)
)
if failed >= count: # Everything failed!
sys.exit(1)
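# Resolve a WorkerElement returned by the dataset listing into a full Element.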
def _format_element(element: WorkerElement) -> Element:
return retrieve_element(element.id)
class DatasetExtractor(DatasetWorker):
@@ -440,7 +306,8 @@ class DatasetExtractor(DatasetWorker):
def process_dataset(self, dataset: Dataset):
# Iterate over given splits
for split_name, elements in self.list_dataset_elements_per_split(dataset):
self.process_split(split_name, elements)
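# The listing yields WorkerElements; cast them to full Elements before processing the split.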
casted_elements = list(map(_format_element, elements))
self.process_split(split_name, casted_elements)
# TAR + ZSTD Image folder and store as task artifact
zstd_archive_path: Path = self.work_dir / f"{dataset.id}.zstd"