New DatasetExtractor using a DatasetWorker

Merged Eva Bardou requested to merge dataset-worker into main
All threads resolved!
2 files changed  +25 −18
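In short, process_split no longer receives a parent folder ID and resolves the children itself; the caller now passes the split's elements directly, fetched from the export database via retrieve_element. A minimal before/after sketch of the call site, using only identifiers that appear in the diff below (worker setup elided):

    # Before: the worker walked the children of a parent folder element
    worker.process_split("train", parent_id)

    # After: the caller supplies the split's elements, e.g. two pages
    worker.process_split(
        "train",
        [retrieve_element(first_page_id), retrieve_element(second_page_id)],
    )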
@@ -11,34 +11,38 @@ from arkindex_worker.cache import (
     CachedTranscription,
     CachedTranscriptionEntity,
 )
+from worker_generic_training_dataset.db import retrieve_element
 from worker_generic_training_dataset.worker import DatasetExtractor
 
 
 def test_process_split(tmp_path, downloaded_images):
-    # Parent is train folder
-    parent_id: UUID = UUID("a0c4522d-2d80-4766-a01c-b9d686f41f6a")
     worker = DatasetExtractor()
     # Parse some arguments
     worker.args = Namespace(database=None)
+    worker.data_folder_path = tmp_path
     worker.configure_cache()
     worker.cached_images = dict()
 
     # Where to save the downloaded images
-    worker.image_folder = tmp_path
-    worker.process_split("train", parent_id)
+    worker.images_folder = tmp_path / "images"
+    worker.images_folder.mkdir(parents=True)
 
-    # Should have created 20 elements in total
-    assert CachedElement.select().count() == 20
+    first_page_id = UUID("e26e6803-18da-4768-be30-a0a68132107c")
+    second_page_id = UUID("c673bd94-96b1-4a2e-8662-a4d806940b5f")
 
-    # Should have created two pages under root folder
-    assert (
-        CachedElement.select().where(CachedElement.parent_id == parent_id).count() == 2
-    )
+    worker.process_split(
+        "train",
+        [
+            retrieve_element(first_page_id),
+            retrieve_element(second_page_id),
+        ],
+    )
 
-    first_page_id = UUID("e26e6803-18da-4768-be30-a0a68132107c")
-    second_page_id = UUID("c673bd94-96b1-4a2e-8662-a4d806940b5f")
 
+    # Should have created 19 elements in total
+    assert CachedElement.select().count() == 19
+    # Should have created two pages at root
+    assert CachedElement.select().where(CachedElement.parent_id.is_null()).count() == 2
 
     # Should have created 8 text_lines under first page
     assert (
@@ -78,11 +82,6 @@ def test_process_split(tmp_path, downloaded_images):
         == f"https://europe-gamma.iiif.teklia.com/iiif/2/public%2Fiam%2F{page_name}.png"
     )
 
-    assert sorted(tmp_path.rglob("*")) == [
-        tmp_path / f"{first_image_id}.jpg",
-        tmp_path / f"{second_image_id}.jpg",
-    ]
-
     # Should have created 17 transcriptions
     assert CachedTranscription.select().count() == 17
 
     # Check transcription of first line on first page
@@ -125,3 +124,11 @@ def test_process_split(tmp_path, downloaded_images):
     assert tr_entity.length == 23
     assert tr_entity.confidence == 1.0
     assert tr_entity.worker_run_id is None
+
+    # Full structure of the archive
+    assert sorted(tmp_path.rglob("*")) == [
+        tmp_path / "db.sqlite",
+        tmp_path / "images",
+        tmp_path / "images" / f"{first_image_id}.jpg",
+        tmp_path / "images" / f"{second_image_id}.jpg",
+    ]
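A side effect of dropping the parent folder is visible in the assertions: the cache now holds 19 elements instead of 20, and the two pages sit at the root of the cache (parent_id is NULL) rather than under the train folder. A self-contained peewee sketch of that query change, with a toy Element model standing in for CachedElement (the model definition here is an assumption for illustration):

    import uuid
    from peewee import Model, SqliteDatabase, TextField, UUIDField

    db = SqliteDatabase(":memory:")

    class Element(Model):
        # Toy stand-in for arkindex_worker.cache.CachedElement
        id = UUIDField(primary_key=True)
        parent_id = UUIDField(null=True)
        type = TextField()

        class Meta:
            database = db

    db.create_tables([Element])
    page = Element.create(id=uuid.uuid4(), parent_id=None, type="page")
    Element.create(id=uuid.uuid4(), parent_id=page.id, type="text_line")

    # Old assertion style: pages counted under a known parent folder
    assert Element.select().where(Element.parent_id == page.id).count() == 1
    # New assertion style: pages live at the root of the cache
    assert Element.select().where(Element.parent_id.is_null()).count() == 1

The final rglob assertion likewise pins the new archive layout: the SQLite cache at the root and the downloaded images under an images/ subfolder, replacing the flat tmp_path layout checked by the removed block earlier in the diff.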