New DatasetExtractor using a DatasetWorker

Merged Eva Bardou requested to merge dataset-worker into main
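
This MR replaces the hand-rolled `run()` loop, which iterated over the `train_folder_id` / `validation_folder_id` / `test_folder_id` process parameters, with the `DatasetWorker` base class from `arkindex_worker`: the base class now drives the worker, calling `process_dataset(dataset)` for each dataset and exposing split contents through `list_dataset_elements_per_split(dataset)`. A minimal sketch of the new control flow, using only names that appear in the diff below (how `DatasetWorker` schedules these calls internally is an assumption, not shown by this MR):

```python
from arkindex_worker.models import Dataset
from arkindex_worker.worker import DatasetWorker


class MinimalDatasetExtractor(DatasetWorker):
    """Sketch of the structure introduced by this MR, not the full worker."""

    def process_dataset(self, dataset: Dataset) -> None:
        # Called by the base class for each dataset handled by the process.
        for split_name, elements in self.list_dataset_elements_per_split(dataset):
            # The real worker maps `elements` through the module-level
            # `_format_element` helper and fills the SQLite cache in process_split.
            print(f"{split_name}: {len(list(elements))} elements")


def main():
    # `generator=True` marks the worker as a dataset artifact generator,
    # mirroring the updated main() at the end of the diff.
    MinimalDatasetExtractor(
        description="Sketch only", support_cache=True, generator=True
    ).run()
```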
# -*- coding: utf-8 -*-
import logging
-import operator
import tempfile
from argparse import Namespace
+from operator import itemgetter
from pathlib import Path
from tempfile import _TemporaryFileWrapper
from typing import List, Optional
from uuid import UUID
from apistar.exceptions import ErrorResponse
-from arkindex_export import Element, Image, open_database
+from arkindex_export import Element, open_database
from arkindex_export.queries import list_children
from arkindex_worker.cache import (
    CachedClassification,
@@ -24,8 +24,10 @@ from arkindex_worker.cache import (
from arkindex_worker.cache import db as cache_database
from arkindex_worker.cache import init_cache_db
from arkindex_worker.image import download_image
+from arkindex_worker.models import Dataset
+from arkindex_worker.models import Element as WorkerElement
from arkindex_worker.utils import create_tar_zst_archive
-from arkindex_worker.worker.base import BaseWorker
+from arkindex_worker.worker import DatasetWorker
from worker_generic_training_dataset.db import (
    list_classifications,
    list_transcription_entities,
@@ -40,7 +42,11 @@ BULK_BATCH_SIZE = 50
DEFAULT_TRANSCRIPTION_ORIENTATION = "horizontal-lr"
-class DatasetExtractor(BaseWorker):
+def _format_element(element: WorkerElement) -> Element:
+    return retrieve_element(element.id)
+class DatasetExtractor(DatasetWorker):
    def configure(self) -> None:
        self.args: Namespace = self.parser.parse_args()
        if self.is_read_only:
@@ -52,12 +58,13 @@ class DatasetExtractor(BaseWorker):
            logger.info("Overriding with user_configuration")
            self.config.update(self.user_configuration)
-        # Read process information
-        self.read_training_related_information()
        # Download corpus
        self.download_latest_export()
+    def configure_storage(self) -> None:
+        self.data_folder = tempfile.TemporaryDirectory(suffix="-arkindex-data")
+        self.data_folder_path = Path(self.data_folder.name)
        # Initialize db that will be written
        self.configure_cache()
@@ -65,39 +72,17 @@ class DatasetExtractor(BaseWorker):
        self.cached_images = dict()
        # Where to save the downloaded images
-        self.image_folder = Path(tempfile.mkdtemp(suffix="-arkindex-data"))
-        logger.info(f"Images will be saved at `{self.image_folder}`.")
-    def read_training_related_information(self) -> None:
-        """
-        Read from process information
-        - train_folder_id
-        - validation_folder_id
-        - test_folder_id (optional)
-        """
-        logger.info("Retrieving information from process_information")
-        train_folder_id = self.process_information.get("train_folder_id")
-        assert train_folder_id, "A training folder id is necessary to use this worker"
-        self.training_folder_id = UUID(train_folder_id)
-        val_folder_id = self.process_information.get("validation_folder_id")
-        assert val_folder_id, "A validation folder id is necessary to use this worker"
-        self.validation_folder_id = UUID(val_folder_id)
-        test_folder_id = self.process_information.get("test_folder_id")
-        self.testing_folder_id: UUID | None = (
-            UUID(test_folder_id) if test_folder_id else None
-        )
+        self.images_folder = self.data_folder_path / "images"
+        self.images_folder.mkdir(parents=True)
+        logger.info(f"Images will be saved at `{self.images_folder}`.")
    def configure_cache(self) -> None:
        """
        Create an SQLite database compatible with base-worker cache and initialize it.
        """
        self.use_cache = True
-        self.cache_path: Path = self.args.database or self.work_dir / "db.sqlite"
-        # Remove previous execution result if present
-        self.cache_path.unlink(missing_ok=True)
+        self.cache_path: Path = self.data_folder_path / "db.sqlite"
        logger.info(f"Cached database will be saved at `{self.cache_path}`.")
        init_cache_db(self.cache_path)
@@ -126,7 +111,7 @@ class DatasetExtractor(BaseWorker):
        # Find the latest that is in "done" state
        exports: List[dict] = sorted(
            list(filter(lambda exp: exp["state"] == "done", exports)),
-            key=operator.itemgetter("updated"),
+            key=itemgetter("updated"),
            reverse=True,
        )
        assert (
@@ -261,7 +246,7 @@ class DatasetExtractor(BaseWorker):
        # Download image
        logger.info("Downloading image")
        download_image(url=build_image_url(element)).save(
-            self.image_folder / f"{element.image.id}.jpg"
+            self.images_folder / f"{element.image.id}.jpg"
        )
        # Insert image
        logger.info("Inserting image")
@@ -301,60 +286,49 @@ class DatasetExtractor(BaseWorker):
        # Insert entities
        self.insert_entities(transcriptions)
-    def process_split(self, split_name: str, split_id: UUID) -> None:
-        """
-        Insert all elements under the given parent folder (all queries are recursive).
-        - `page` elements are linked to this folder (via parent_id foreign key)
-        - `page` element children are linked to their `page` parent (via parent_id foreign key)
-        """
+    def process_split(self, split_name: str, elements: List[Element]) -> None:
        logger.info(
-            f"Filling the Base-Worker cache with information from children under element ({split_id})"
+            f"Filling the cache with information from elements in the split {split_name}"
        )
-        # Fill cache
-        # Retrieve parent and create parent
-        parent: Element = retrieve_element(split_id)
-        self.insert_element(parent)
-        # First list all pages
-        pages = list_children(split_id).join(Image).where(Element.type == "page")
-        nb_pages: int = pages.count()
-        for idx, page in enumerate(pages, start=1):
-            logger.info(f"Processing `{split_name}` page ({idx}/{nb_pages})")
+        nb_elements: int = len(elements)
+        for idx, element in enumerate(elements, start=1):
+            logger.info(f"Processing `{split_name}` element ({idx}/{nb_elements})")
            # Insert page
-            self.insert_element(page, parent_id=split_id)
+            self.insert_element(element)
            # List children
-            children = list_children(page.id)
+            children = list_children(element.id)
            nb_children: int = children.count()
            for child_idx, child in enumerate(children, start=1):
                logger.info(f"Processing child ({child_idx}/{nb_children})")
                # Insert child
-                self.insert_element(child, parent_id=page.id)
-    def run(self):
-        self.configure()
-        # Iterate over given split
-        for split_name, split_id in [
-            ("Train", self.training_folder_id),
-            ("Validation", self.validation_folder_id),
-            ("Test", self.testing_folder_id),
-        ]:
-            if not split_id:
-                continue
-            self.process_split(split_name, split_id)
-        # TAR + ZSTD Image folder and store as task artifact
-        zstd_archive_path: Path = self.work_dir / "arkindex_data.zstd"
+                self.insert_element(child, parent_id=element.id)
+    def process_dataset(self, dataset: Dataset):
+        # Configure temporary storage for the dataset data (cache + images)
+        self.configure_storage()
+        # Iterate over given splits
+        for split_name, elements in self.list_dataset_elements_per_split(dataset):
+            casted_elements = list(map(_format_element, elements))
+            self.process_split(split_name, casted_elements)
+        # TAR + ZSTD the cache and the images folder, and store as task artifact
+        zstd_archive_path: Path = self.work_dir / f"{dataset.id}.zstd"
        logger.info(f"Compressing the images to {zstd_archive_path}")
-        create_tar_zst_archive(source=self.image_folder, destination=zstd_archive_path)
+        create_tar_zst_archive(
+            source=self.data_folder_path, destination=zstd_archive_path
+        )
+        self.data_folder.cleanup()
def main():
    DatasetExtractor(
        description="Fill base-worker cache with information about dataset and extract images",
        support_cache=True,
+        generator=True,
    ).run()
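
The artifact produced by `process_dataset` is now a single tar archive compressed with zstd, named after the dataset id and containing both the SQLite cache (`db.sqlite`) and the downloaded images (`images/<image_id>.jpg`), i.e. the contents of `data_folder_path` before `cleanup()`. A sketch of how a downstream task might unpack it; the `zstandard` package and the helper name are assumptions for illustration, not part of this MR:

```python
import tarfile
from pathlib import Path

import zstandard  # assumed available in the consuming task


def extract_dataset_artifact(archive: Path, destination: Path) -> Path:
    """Unpack a <dataset.id>.zstd artifact and return the path to the cache DB."""
    destination.mkdir(parents=True, exist_ok=True)
    dctx = zstandard.ZstdDecompressor()
    # Stream-decompress the zstd layer, then read the inner tar sequentially.
    with archive.open("rb") as compressed, dctx.stream_reader(compressed) as reader:
        with tarfile.open(fileobj=reader, mode="r|") as tar:
            tar.extractall(destination)
    return destination / "db.sqlite"
```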