Commit 2ce6237a authored by Eva Bardou

Use DatasetWorker from arkindex-base-worker

parent 03aad434
1 merge request: !8 New DatasetExtractor using a DatasetWorker
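The diff below drops the repository's local DatasetWorker class and subclasses the one shipped by arkindex-base-worker instead. As a minimal sketch of the resulting pattern, assuming the upstream class exposes the same interface as the local one it replaces (a description and generator argument, list_dataset_elements_per_split() and a run() entry point), with MyDatasetWorker and its body purely illustrative:

# Illustrative sketch only: subclass the DatasetWorker now provided by
# arkindex-base-worker; the interface is assumed from the local class removed below.
from arkindex_worker.models import Dataset
from arkindex_worker.worker import DatasetWorker


class MyDatasetWorker(DatasetWorker):
    def process_dataset(self, dataset: Dataset):
        # The base class is expected to group dataset elements by split,
        # as the removed local list_dataset_elements_per_split() did.
        for split_name, elements in self.list_dataset_elements_per_split(dataset):
            ...  # build the split from the listed elements


def main():
    MyDatasetWorker(
        description="My dataset builder",
        generator=True,  # generator workers take an Open dataset and build it
    ).run()


if __name__ == "__main__":
    main()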
 # -*- coding: utf-8 -*-
 import logging
-import sys
 import tempfile
 from argparse import Namespace
-from itertools import groupby
 from operator import itemgetter
 from pathlib import Path
 from tempfile import _TemporaryFileWrapper
-from typing import Iterator, List, Optional, Tuple
+from typing import List, Optional
 from uuid import UUID
 
 from apistar.exceptions import ErrorResponse
@@ -29,8 +27,7 @@ from arkindex_worker.image import download_image
 from arkindex_worker.models import Dataset
 from arkindex_worker.models import Element as WorkerElement
 from arkindex_worker.utils import create_tar_zst_archive
-from arkindex_worker.worker.base import BaseWorker
-from arkindex_worker.worker.dataset import DatasetMixin, DatasetState
+from arkindex_worker.worker import DatasetWorker
 from worker_generic_training_dataset.db import (
     list_classifications,
     list_transcription_entities,
@@ -45,139 +42,8 @@ BULK_BATCH_SIZE = 50
 DEFAULT_TRANSCRIPTION_ORIENTATION = "horizontal-lr"
 
-class DatasetWorker(BaseWorker, DatasetMixin):
-    def __init__(
-        self,
-        description: str = "Arkindex Elements Worker",
-        support_cache: bool = False,
-        generator: bool = False,
-    ):
-        super().__init__(description, support_cache)
-
-        self.parser.add_argument(
-            "--dataset",
-            type=UUID,
-            nargs="+",
-            help="One or more Arkindex dataset ID",
-        )
-
-        self.generator = generator
-
-    def list_dataset_elements_per_split(
-        self, dataset: Dataset
-    ) -> Iterator[Tuple[str, List[Element]]]:
-        """
-        Calls `list_dataset_elements` but returns results grouped by Set
-        """
-
-        def format_element(element: Tuple[str, WorkerElement]) -> Element:
-            return retrieve_element(element[1].id)
-
-        def format_split(
-            split: Tuple[str, Iterator[Tuple[str, WorkerElement]]]
-        ) -> Tuple[str, List[Element]]:
-            return (split[0], list(map(format_element, list(split[1]))))
-
-        return map(
-            format_split,
-            groupby(
-                sorted(self.list_dataset_elements(dataset), key=itemgetter(0)),
-                key=itemgetter(0),
-            ),
-        )
-
-    def process_dataset(self, dataset: Dataset):
-        """
-        Override this method to implement your worker and process a single Arkindex dataset at once.
-
-        :param dataset: The dataset to process.
-        """
-
-    def list_datasets(self) -> Iterator[Dataset] | Iterator[str]:
-        """
-        Calls `list_process_datasets` if not is_read_only,
-        else simply give the list of IDs provided via CLI
-        """
-        if self.is_read_only:
-            return map(str, self.args.dataset)
-
-        return self.list_process_datasets()
-
-    def run(self):
-        self.configure()
-
-        datasets: Iterator[Dataset] | Iterator[str] = list(self.list_datasets())
-        if not datasets:
-            logger.warning("No datasets to process, stopping.")
-            sys.exit(1)
-
-        # Process every dataset
-        count = len(datasets)
-        failed = 0
-        for i, item in enumerate(datasets, start=1):
-            dataset = None
-            try:
-                if not self.is_read_only:
-                    # Just use the result of list_datasets as the dataset
-                    dataset = item
-                else:
-                    # Load dataset using the Arkindex API
-                    dataset = Dataset(**self.request("RetrieveDataset", id=item))
-
-                if self.generator:
-                    assert (
-                        dataset.state == DatasetState.Open.value
-                    ), "When generating a new dataset, its state should be Open."
-                else:
-                    assert (
-                        dataset.state == DatasetState.Complete.value
-                    ), "When processing an existing dataset, its state should be Complete."
-
-                if self.generator:
-                    # Update the dataset state to Building
-                    logger.info(f"Building {dataset} ({i}/{count})")
-                    self.update_dataset_state(dataset, DatasetState.Building)
-
-                # Process the dataset
-                self.process_dataset(dataset)
-
-                if self.generator:
-                    # Update the dataset state to Complete
-                    logger.info(f"Completed {dataset} ({i}/{count})")
-                    self.update_dataset_state(dataset, DatasetState.Complete)
-            except Exception as e:
-                # Handle errors occurring while retrieving, processing or patching the state for this dataset.
-                failed += 1
-
-                # Handle the case where we failed retrieving the dataset
-                dataset_id = dataset.id if dataset else item
-
-                if isinstance(e, ErrorResponse):
-                    message = f"An API error occurred while processing dataset {dataset_id}: {e.title} - {e.content}"
-                else:
-                    message = (
-                        f"Failed running worker on dataset {dataset_id}: {repr(e)}"
-                    )
-
-                logger.warning(
-                    message,
-                    exc_info=e if self.args.verbose else None,
-                )
-                if dataset and self.generator:
-                    # Try to update the state to Error regardless of the response
-                    try:
-                        self.update_dataset_state(dataset, DatasetState.Error)
-                    except Exception:
-                        pass
-
-        if failed:
-            logger.error(
-                "Ran on {} dataset: {} completed, {} failed".format(
-                    count, count - failed, failed
-                )
-            )
-            if failed >= count:  # Everything failed!
-                sys.exit(1)
-
+def _format_element(element: WorkerElement) -> Element:
+    return retrieve_element(element.id)
 
 class DatasetExtractor(DatasetWorker):
@@ -440,7 +306,8 @@ class DatasetExtractor(DatasetWorker):
     def process_dataset(self, dataset: Dataset):
         # Iterate over given splits
         for split_name, elements in self.list_dataset_elements_per_split(dataset):
-            self.process_split(split_name, elements)
+            casted_elements = list(map(_format_element, elements))
+            self.process_split(split_name, casted_elements)
 
         # TAR + ZSTD Image folder and store as task artifact
         zstd_archive_path: Path = self.work_dir / f"{dataset.id}.zstd"
...
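For reference, the removed run() loop enforced the dataset lifecycle in generator mode, and the upstream DatasetWorker is now relied on to handle it. A sketch based on the deleted code, with a hypothetical run_one_dataset helper; the upstream implementation may differ in its details:

# Sketch of the state handling taken from the removed run() loop (generator mode):
# Open -> Building -> Complete, or Error when anything fails.
from arkindex_worker.worker.dataset import DatasetState


def run_one_dataset(worker, dataset):
    # A generator worker must start from an Open dataset
    assert dataset.state == DatasetState.Open.value
    try:
        worker.update_dataset_state(dataset, DatasetState.Building)
        worker.process_dataset(dataset)
        worker.update_dataset_state(dataset, DatasetState.Complete)
    except Exception:
        # Flag the dataset as failed, ignoring errors from the state update itself
        try:
            worker.update_dataset_state(dataset, DatasetState.Error)
        except Exception:
            pass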