
New DatasetExtractor using a DatasetWorker

Merged Eva Bardou requested to merge dataset-worker into main
All threads resolved!
1 file changed: +26 -25
@@ -27,6 +27,7 @@ from arkindex_worker.cache import db as cache_database
 from arkindex_worker.cache import init_cache_db
 from arkindex_worker.image import download_image
 from arkindex_worker.models import Dataset
+from arkindex_worker.models import Element as WorkerElement
 from arkindex_worker.utils import create_tar_zst_archive
 from arkindex_worker.worker.base import BaseWorker
 from arkindex_worker.worker.dataset import DatasetMixin, DatasetState
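Note on the added import: the annotations in the next hunk use two different element classes, a local Peewee cache model called Element (queried with Element.get(Element.id == ...)) and the Arkindex API model, which therefore comes in under the alias WorkerElement. A minimal sketch of that pattern, with a made-up Peewee model standing in for the project's real cache model:

    from peewee import CharField, Model, SqliteDatabase

    from arkindex_worker.models import Element as WorkerElement  # API-side element

    db = SqliteDatabase(":memory:")


    class Element(Model):
        # Made-up stand-in for the local cache model that shares the name "Element"
        id = CharField(primary_key=True)

        class Meta:
            database = db


    def to_cached_element(api_element: WorkerElement) -> Element:
        # Resolve an API element to its row in the local cache database
        return Element.get(Element.id == api_element.id)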
@@ -61,27 +62,27 @@ class DatasetWorker(BaseWorker, DatasetMixin):
         self.generator = generator

-    def list_dataset_elements_per_set(
+    def list_dataset_elements_per_split(
         self, dataset: Dataset
-    ) -> Iterator[Tuple[str, Element]]:
+    ) -> Iterator[Tuple[str, List[Element]]]:
         """
         Calls `list_dataset_elements` but returns results grouped by Set
         """

-        def format_element(element):
+        def format_element(element: Tuple[str, WorkerElement]) -> Element:
             return Element.get(Element.id == element[1].id)

-        def format_set(set):
-            return (set[0], list(map(format_element, list(set[1]))))
-
-        return list(
-            map(
-                format_set,
-                groupby(
-                    sorted(self.list_dataset_elements(dataset), key=itemgetter(0)),
-                    key=itemgetter(0),
-                ),
-            )
-        )
+        def format_split(
+            split: Tuple[str, Iterator[Tuple[str, WorkerElement]]]
+        ) -> Tuple[str, List[Element]]:
+            return (split[0], list(map(format_element, list(split[1]))))
+
+        return map(
+            format_split,
+            groupby(
+                sorted(self.list_dataset_elements(dataset), key=itemgetter(0)),
+                key=itemgetter(0),
+            ),
+        )

     def process_dataset(self, dataset: Dataset):
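For reference, a minimal standalone sketch of the sorted-then-groupby pattern used by list_dataset_elements_per_split, with plain (split_name, element_id) tuples standing in for the worker's (set name, element) pairs; the values are made up for illustration:

    from itertools import groupby
    from operator import itemgetter

    # Stand-in for self.list_dataset_elements(dataset): pairs are not
    # guaranteed to arrive ordered by split.
    pairs = [
        ("train", "elt-1"),
        ("test", "elt-3"),
        ("train", "elt-2"),
        ("validation", "elt-4"),
    ]

    # groupby() only merges consecutive identical keys, hence the sort on the
    # split name before grouping, exactly as in the method above.
    per_split = {
        split: [element_id for _, element_id in group]
        for split, group in groupby(sorted(pairs, key=itemgetter(0)), key=itemgetter(0))
    }
    print(per_split)
    # {'test': ['elt-3'], 'train': ['elt-1', 'elt-2'], 'validation': ['elt-4']}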
@@ -91,20 +92,20 @@ class DatasetWorker(BaseWorker, DatasetMixin):
         :param dataset: The dataset to process.
         """

-    def list_datasets(self) -> List[Dataset] | List[str]:
+    def list_datasets(self) -> Iterator[Dataset] | Iterator[str]:
         """
         Calls `list_process_datasets` if not is_read_only,
         else simply give the list of IDs provided via CLI
         """
         if self.is_read_only:
-            return list(map(str, self.args.dataset))
+            return map(str, self.args.dataset)

         return self.list_process_datasets()

     def run(self):
         self.configure()

-        datasets: List[Dataset] | List[str] = self.list_datasets()
+        datasets: Iterator[Dataset] | Iterator[str] = list(self.list_datasets())
         if not datasets:
             logger.warning("No datasets to process, stopping.")
             sys.exit(1)
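The move to iterators explains why run() now wraps list_datasets() in list(): a map object is always truthy, even when it will yield nothing, so the emptiness check only works on a materialised list. A quick illustration:

    empty = map(str, [])
    print(bool(empty))        # True: map objects are always truthy
    print(bool(list(empty)))  # False: list() makes "no datasets" detectable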
@@ -125,11 +126,11 @@ class DatasetWorker(BaseWorker, DatasetMixin):
             if self.generator:
                 assert (
                     dataset.state == DatasetState.Open.value
-                ), "When generating a new dataset, its state should be Open"
+                ), "When generating a new dataset, its state should be Open."
             else:
                 assert (
                     dataset.state == DatasetState.Complete.value
-                ), "When processing an existing dataset, its state should be Complete"
+                ), "When processing an existing dataset, its state should be Complete."

             if self.generator:
                 # Update the dataset state to Building
@@ -414,15 +415,15 @@ class DatasetExtractor(DatasetWorker):
         # Insert entities
         self.insert_entities(transcriptions)

-    def process_set(self, set_name: str, elements: List[Element]) -> None:
+    def process_split(self, split_name: str, elements: List[Element]) -> None:
         logger.info(
-            f"Filling the cache with information from elements in the set {set_name}"
+            f"Filling the cache with information from elements in the split {split_name}"
         )

         # First list all pages
         nb_elements: int = len(elements)
         for idx, element in enumerate(elements, start=1):
-            logger.info(f"Processing `{set_name}` element ({idx}/{nb_elements})")
+            logger.info(f"Processing `{split_name}` element ({idx}/{nb_elements})")

             # Insert page
             self.insert_element(element)
@@ -436,9 +437,9 @@ class DatasetExtractor(DatasetWorker):
             self.insert_element(child, parent_id=element.id)

     def process_dataset(self, dataset: Dataset):
-        # Iterate over given sets
-        for set_name, elements in self.list_dataset_elements_per_set(dataset):
-            self.process_set(set_name, elements)
+        # Iterate over given splits
+        for split_name, elements in self.list_dataset_elements_per_split(dataset):
+            self.process_split(split_name, elements)

         # TAR + ZSTD Image folder and store as task artifact
         zstd_archive_path: Path = self.work_dir / f"{dataset.id}.zstd"
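The archive step relies on create_tar_zst_archive from arkindex_worker.utils. As a rough idea of what that packaging does (not the helper's actual implementation or signature), here is a sketch using tarfile plus the zstandard package:

    import tarfile
    from pathlib import Path

    import zstandard  # assumption: the zstandard bindings are available


    def tar_zst_folder(source: Path, destination: Path) -> None:
        # Hypothetical stand-in for create_tar_zst_archive: stream a tar of
        # `source` through a Zstandard compressor into `destination`.
        compressor = zstandard.ZstdCompressor()
        with destination.open("wb") as archive_file:
            with compressor.stream_writer(archive_file) as compressed:
                with tarfile.open(fileobj=compressed, mode="w|") as tar:
                    tar.add(source, arcname=source.name)


    # e.g. tar_zst_folder(images_dir, zstd_archive_path) with paths of your own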