Commit f2988ecd authored by Eva Bardou, committed by Yoann Schneider

New DatasetExtractor using a DatasetWorker

parent 07969c74
1 merge request: !8 New DatasetExtractor using a DatasetWorker
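In short, the extractor no longer subclasses BaseWorker and reads train/validation/test folder IDs from user_configuration; it now subclasses DatasetWorker and receives its splits from the Arkindex Dataset it processes. A minimal sketch of the new control flow, using only names that appear in the diff below (anything beyond those names is illustrative):

from arkindex_worker.worker import DatasetWorker

class DatasetExtractor(DatasetWorker):
    def process_dataset(self, dataset):
        # Splits (train/validation/test) come from the Dataset itself,
        # not from folder IDs in the process configuration.
        for split_name, elements in self.list_dataset_elements_per_split(dataset):
            self.process_split(split_name, elements)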
@@ -9,16 +9,3 @@ workers:
type: data-extract
docker:
build: Dockerfile
user_configuration:
train_folder_id:
type: string
title: ID of the training folder on Arkindex
required: true
validation_folder_id:
type: string
title: ID of the validation folder on Arkindex
required: true
test_folder_id:
type: string
title: ID of the testing folder on Arkindex
required: true
@@ -11,34 +11,38 @@ from arkindex_worker.cache import (
CachedTranscription,
CachedTranscriptionEntity,
)
from worker_generic_training_dataset.db import retrieve_element
from worker_generic_training_dataset.worker import DatasetExtractor
def test_process_split(tmp_path, downloaded_images):
# Parent is train folder
parent_id: UUID = UUID("a0c4522d-2d80-4766-a01c-b9d686f41f6a")
worker = DatasetExtractor()
# Parse some arguments
worker.args = Namespace(database=None)
worker.data_folder_path = tmp_path
worker.configure_cache()
worker.cached_images = dict()
# Where to save the downloaded images
worker.image_folder = tmp_path
worker.process_split("train", parent_id)
worker.images_folder = tmp_path / "images"
worker.images_folder.mkdir(parents=True)
# Should have created 20 elements in total
assert CachedElement.select().count() == 20
first_page_id = UUID("e26e6803-18da-4768-be30-a0a68132107c")
second_page_id = UUID("c673bd94-96b1-4a2e-8662-a4d806940b5f")
# Should have created two pages under root folder
assert (
CachedElement.select().where(CachedElement.parent_id == parent_id).count() == 2
worker.process_split(
"train",
[
retrieve_element(first_page_id),
retrieve_element(second_page_id),
],
)
first_page_id = UUID("e26e6803-18da-4768-be30-a0a68132107c")
second_page_id = UUID("c673bd94-96b1-4a2e-8662-a4d806940b5f")
# Should have created 19 elements in total
assert CachedElement.select().count() == 19
# Should have created two pages at root
assert CachedElement.select().where(CachedElement.parent_id.is_null()).count() == 2
# Should have created 8 text_lines under first page
assert (
@@ -78,11 +82,6 @@ def test_process_split(tmp_path, downloaded_images):
== f"https://europe-gamma.iiif.teklia.com/iiif/2/public%2Fiam%2F{page_name}.png"
)
assert sorted(tmp_path.rglob("*")) == [
tmp_path / f"{first_image_id}.jpg",
tmp_path / f"{second_image_id}.jpg",
]
# Should have created 17 transcriptions
assert CachedTranscription.select().count() == 17
# Check transcription of first line on first page
@@ -125,3 +124,11 @@ def test_process_split(tmp_path, downloaded_images):
assert tr_entity.length == 23
assert tr_entity.confidence == 1.0
assert tr_entity.worker_run_id is None
# Full structure of the archive
assert sorted(tmp_path.rglob("*")) == [
tmp_path / "db.sqlite",
tmp_path / "images",
tmp_path / "images" / f"{first_image_id}.jpg",
tmp_path / "images" / f"{second_image_id}.jpg",
]
# -*- coding: utf-8 -*-
import logging
import operator
import tempfile
from argparse import Namespace
from operator import itemgetter
from pathlib import Path
from tempfile import _TemporaryFileWrapper
from typing import List, Optional
from uuid import UUID
from apistar.exceptions import ErrorResponse
from arkindex_export import Element, Image, open_database
from arkindex_export import Element, open_database
from arkindex_export.queries import list_children
from arkindex_worker.cache import (
CachedClassification,
@@ -24,8 +24,10 @@ from arkindex_worker.cache import (
from arkindex_worker.cache import db as cache_database
from arkindex_worker.cache import init_cache_db
from arkindex_worker.image import download_image
from arkindex_worker.models import Dataset
from arkindex_worker.models import Element as WorkerElement
from arkindex_worker.utils import create_tar_zst_archive
from arkindex_worker.worker.base import BaseWorker
from arkindex_worker.worker import DatasetWorker
from worker_generic_training_dataset.db import (
list_classifications,
list_transcription_entities,
@@ -40,7 +42,11 @@ BULK_BATCH_SIZE = 50
DEFAULT_TRANSCRIPTION_ORIENTATION = "horizontal-lr"
class DatasetExtractor(BaseWorker):
def _format_element(element: WorkerElement) -> Element:
return retrieve_element(element.id)
class DatasetExtractor(DatasetWorker):
def configure(self) -> None:
self.args: Namespace = self.parser.parse_args()
if self.is_read_only:
@@ -52,12 +58,13 @@ class DatasetExtractor(BaseWorker):
logger.info("Overriding with user_configuration")
self.config.update(self.user_configuration)
# Read process information
self.read_training_related_information()
# Download corpus
self.download_latest_export()
def configure_storage(self) -> None:
self.data_folder = tempfile.TemporaryDirectory(suffix="-arkindex-data")
self.data_folder_path = Path(self.data_folder.name)
# Initialize db that will be written
self.configure_cache()
@@ -65,39 +72,17 @@ class DatasetExtractor(BaseWorker):
self.cached_images = dict()
# Where to save the downloaded images
self.image_folder = Path(tempfile.mkdtemp(suffix="-arkindex-data"))
logger.info(f"Images will be saved at `{self.image_folder}`.")
def read_training_related_information(self) -> None:
"""
Read from process information
- train_folder_id
- validation_folder_id
- test_folder_id (optional)
"""
logger.info("Retrieving information from process_information")
train_folder_id = self.process_information.get("train_folder_id")
assert train_folder_id, "A training folder id is necessary to use this worker"
self.training_folder_id = UUID(train_folder_id)
val_folder_id = self.process_information.get("validation_folder_id")
assert val_folder_id, "A validation folder id is necessary to use this worker"
self.validation_folder_id = UUID(val_folder_id)
test_folder_id = self.process_information.get("test_folder_id")
self.testing_folder_id: UUID | None = (
UUID(test_folder_id) if test_folder_id else None
)
self.images_folder = self.data_folder_path / "images"
self.images_folder.mkdir(parents=True)
logger.info(f"Images will be saved at `{self.images_folder}`.")
def configure_cache(self) -> None:
"""
Create an SQLite database compatible with base-worker cache and initialize it.
"""
self.use_cache = True
self.cache_path: Path = self.args.database or self.work_dir / "db.sqlite"
# Remove previous execution result if present
self.cache_path.unlink(missing_ok=True)
self.cache_path: Path = self.data_folder_path / "db.sqlite"
logger.info(f"Cached database will be saved at `{self.cache_path}`.")
init_cache_db(self.cache_path)
@@ -126,7 +111,7 @@ class DatasetExtractor(BaseWorker):
# Find the latest that is in "done" state
exports: List[dict] = sorted(
list(filter(lambda exp: exp["state"] == "done", exports)),
key=operator.itemgetter("updated"),
key=itemgetter("updated"),
reverse=True,
)
assert (
@@ -261,7 +246,7 @@ class DatasetExtractor(BaseWorker):
# Download image
logger.info("Downloading image")
download_image(url=build_image_url(element)).save(
self.image_folder / f"{element.image.id}.jpg"
self.images_folder / f"{element.image.id}.jpg"
)
# Insert image
logger.info("Inserting image")
@@ -301,60 +286,49 @@ class DatasetExtractor(BaseWorker):
# Insert entities
self.insert_entities(transcriptions)
def process_split(self, split_name: str, split_id: UUID) -> None:
"""
Insert all elements under the given parent folder (all queries are recursive).
- `page` elements are linked to this folder (via parent_id foreign key)
- `page` element children are linked to their `page` parent (via parent_id foreign key)
"""
def process_split(self, split_name: str, elements: List[Element]) -> None:
logger.info(
f"Filling the Base-Worker cache with information from children under element ({split_id})"
f"Filling the cache with information from elements in the split {split_name}"
)
# Fill cache
# Retrieve parent and create parent
parent: Element = retrieve_element(split_id)
self.insert_element(parent)
# First list all pages
pages = list_children(split_id).join(Image).where(Element.type == "page")
nb_pages: int = pages.count()
for idx, page in enumerate(pages, start=1):
logger.info(f"Processing `{split_name}` page ({idx}/{nb_pages})")
nb_elements: int = len(elements)
for idx, element in enumerate(elements, start=1):
logger.info(f"Processing `{split_name}` element ({idx}/{nb_elements})")
# Insert page
self.insert_element(page, parent_id=split_id)
self.insert_element(element)
# List children
children = list_children(page.id)
children = list_children(element.id)
nb_children: int = children.count()
for child_idx, child in enumerate(children, start=1):
logger.info(f"Processing child ({child_idx}/{nb_children})")
# Insert child
self.insert_element(child, parent_id=page.id)
def run(self):
self.configure()
# Iterate over given split
for split_name, split_id in [
("Train", self.training_folder_id),
("Validation", self.validation_folder_id),
("Test", self.testing_folder_id),
]:
if not split_id:
continue
self.process_split(split_name, split_id)
# TAR + ZSTD Image folder and store as task artifact
zstd_archive_path: Path = self.work_dir / "arkindex_data.zstd"
self.insert_element(child, parent_id=element.id)
def process_dataset(self, dataset: Dataset):
# Configure temporary storage for the dataset data (cache + images)
self.configure_storage()
# Iterate over given splits
for split_name, elements in self.list_dataset_elements_per_split(dataset):
casted_elements = list(map(_format_element, elements))
self.process_split(split_name, casted_elements)
# TAR + ZSTD the cache and the images folder, and store as task artifact
zstd_archive_path: Path = self.work_dir / f"{dataset.id}.zstd"
logger.info(f"Compressing the images to {zstd_archive_path}")
create_tar_zst_archive(source=self.image_folder, destination=zstd_archive_path)
create_tar_zst_archive(
source=self.data_folder_path, destination=zstd_archive_path
)
self.data_folder.cleanup()
def main():
DatasetExtractor(
description="Fill base-worker cache with information about dataset and extract images",
support_cache=True,
generator=True,
).run()
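For downstream consumers: the task artifact written by process_dataset is a zstd-compressed tar archive named {dataset.id}.zstd, and the test above asserts it contains db.sqlite plus an images/ folder of JPEGs. A hedged sketch of unpacking such an artifact, assuming the standard tarfile module and the third-party zstandard package (this unpacking code is illustrative and not part of the commit):

import tarfile
from pathlib import Path

import zstandard  # assumed third-party dependency for this sketch


def extract_dataset_archive(archive_path: Path, destination: Path) -> Path:
    # Decompress the zstd stream and unpack the tar archive it wraps.
    destination.mkdir(parents=True, exist_ok=True)
    with archive_path.open("rb") as fh:
        with zstandard.ZstdDecompressor().stream_reader(fh) as reader:
            with tarfile.open(fileobj=reader, mode="r|") as tar:
                tar.extractall(destination)
    # Expected contents (per the test above): db.sqlite and images/<image_id>.jpg
    return destination / "db.sqlite"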