diff --git a/worker_generic_training_dataset/worker.py b/worker_generic_training_dataset/worker.py
index 532104c4005745c0798b4b9d0de178f099f489d2..e29b7731907d305125442eb6d30915a8f65f5166 100644
--- a/worker_generic_training_dataset/worker.py
+++ b/worker_generic_training_dataset/worker.py
@@ -27,6 +27,7 @@ from arkindex_worker.cache import db as cache_database
 from arkindex_worker.cache import init_cache_db
 from arkindex_worker.image import download_image
 from arkindex_worker.models import Dataset
+from arkindex_worker.models import Element as WorkerElement
 from arkindex_worker.utils import create_tar_zst_archive
 from arkindex_worker.worker.base import BaseWorker
 from arkindex_worker.worker.dataset import DatasetMixin, DatasetState
@@ -61,27 +62,27 @@ class DatasetWorker(BaseWorker, DatasetMixin):
 
         self.generator = generator
 
-    def list_dataset_elements_per_set(
+    def list_dataset_elements_per_split(
         self, dataset: Dataset
-    ) -> Iterator[Tuple[str, Element]]:
+    ) -> Iterator[Tuple[str, List[Element]]]:
         """
         Calls `list_dataset_elements` but returns results grouped by Set
         """
 
-        def format_element(element):
+        def format_element(element: Tuple[str, WorkerElement]) -> Element:
             return Element.get(Element.id == element[1].id)
 
-        def format_set(set):
-            return (set[0], list(map(format_element, list(set[1]))))
-
-        return list(
-            map(
-                format_set,
-                groupby(
-                    sorted(self.list_dataset_elements(dataset), key=itemgetter(0)),
-                    key=itemgetter(0),
-                ),
-            )
+        def format_split(
+            split: Tuple[str, Iterator[Tuple[str, WorkerElement]]]
+        ) -> Tuple[str, List[Element]]:
+            return (split[0], list(map(format_element, list(split[1]))))
+
+        return map(
+            format_split,
+            groupby(
+                sorted(self.list_dataset_elements(dataset), key=itemgetter(0)),
+                key=itemgetter(0),
+            ),
         )
 
     def process_dataset(self, dataset: Dataset):
@@ -91,20 +92,20 @@ class DatasetWorker(BaseWorker, DatasetMixin):
         :param dataset: The dataset to process.
         """
 
-    def list_datasets(self) -> List[Dataset] | List[str]:
+    def list_datasets(self) -> Iterator[Dataset] | Iterator[str]:
         """
         Calls `list_process_datasets` if not is_read_only,
         else simply give the list of IDs provided via CLI
         """
         if self.is_read_only:
-            return list(map(str, self.args.dataset))
+            return map(str, self.args.dataset)
 
         return self.list_process_datasets()
 
     def run(self):
         self.configure()
 
-        datasets: List[Dataset] | List[str] = self.list_datasets()
+        datasets: Iterator[Dataset] | Iterator[str] = list(self.list_datasets())
         if not datasets:
             logger.warning("No datasets to process, stopping.")
             sys.exit(1)
@@ -125,11 +126,11 @@ class DatasetWorker(BaseWorker, DatasetMixin):
                 if self.generator:
                     assert (
                         dataset.state == DatasetState.Open.value
-                    ), "When generating a new dataset, its state should be Open"
+                    ), "When generating a new dataset, its state should be Open."
                 else:
                     assert (
                         dataset.state == DatasetState.Complete.value
-                    ), "When processing an existing dataset, its state should be Complete"
+                    ), "When processing an existing dataset, its state should be Complete."
 
                 if self.generator:
                     # Update the dataset state to Building
@@ -414,15 +415,15 @@ class DatasetExtractor(DatasetWorker):
         # Insert entities
         self.insert_entities(transcriptions)
 
-    def process_set(self, set_name: str, elements: List[Element]) -> None:
+    def process_split(self, split_name: str, elements: List[Element]) -> None:
         logger.info(
-            f"Filling the cache with information from elements in the set {set_name}"
+            f"Filling the cache with information from elements in the split {split_name}"
         )
 
         # First list all pages
         nb_elements: int = len(elements)
         for idx, element in enumerate(elements, start=1):
-            logger.info(f"Processing `{set_name}` element ({idx}/{nb_elements})")
+            logger.info(f"Processing `{split_name}` element ({idx}/{nb_elements})")
 
             # Insert page
             self.insert_element(element)
@@ -436,9 +437,9 @@ class DatasetExtractor(DatasetWorker):
                 self.insert_element(child, parent_id=element.id)
 
     def process_dataset(self, dataset: Dataset):
-        # Iterate over given sets
-        for set_name, elements in self.list_dataset_elements_per_set(dataset):
-            self.process_set(set_name, elements)
+        # Iterate over given splits
+        for split_name, elements in self.list_dataset_elements_per_split(dataset):
+            self.process_split(split_name, elements)
 
         # TAR + ZSTD Image folder and store as task artifact
         zstd_archive_path: Path = self.work_dir / f"{dataset.id}.zstd"
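A note on the grouping logic in `list_dataset_elements_per_split`: `itertools.groupby` only merges *consecutive* items sharing the same key, which is why the output of `list_dataset_elements` is sorted on the split name (`itemgetter(0)`) before being grouped. Below is a minimal standalone sketch of that sort-then-group pattern; the split names and element IDs are made up purely for illustration and are not part of the patch.

```python
from itertools import groupby
from operator import itemgetter

# Hypothetical (split_name, element_id) pairs, deliberately out of order,
# mimicking what list_dataset_elements could yield.
pairs = [
    ("train", "elem-1"),
    ("validation", "elem-3"),
    ("train", "elem-2"),
    ("test", "elem-4"),
]

# groupby only groups consecutive items, so sort by split name first.
grouped = [
    (split_name, [element_id for _, element_id in items])
    for split_name, items in groupby(
        sorted(pairs, key=itemgetter(0)), key=itemgetter(0)
    )
]

print(grouped)
# [('test', ['elem-4']), ('train', ['elem-1', 'elem-2']), ('validation', ['elem-3'])]
```

Laziness is presumably also why `run` now wraps `list_datasets()` in `list(...)`: the method can return a `map` object, which is always truthy, so the emptiness check `if not datasets` only works once the iterator has been materialized.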