# Test the SQL-based dataset extraction worker: process one "train" split
# and check everything it writes to the cache database and images folder.
import json
from argparse import Namespace
from uuid import UUID, uuid4

from arkindex_export.models import Element
from arkindex_worker.cache import (
    CachedClassification,
    CachedDataset,
    CachedDatasetElement,
    CachedElement,
    CachedEntity,
    CachedImage,
    CachedTranscription,
    CachedTranscriptionEntity,
)
from worker_generic_training_dataset.from_sql import DatasetExtractorFromSQL
# Build a worker that reads from the SQL export (no API database configured)
worker = DatasetExtractorFromSQL()
worker.args = Namespace(database=None)
worker.data_folder_path = tmp_path
worker.cached_images = {}
# Where to save the downloaded images
worker.images_folder = tmp_path / "images"
worker.images_folder.mkdir(parents=True)

first_page_id = UUID("e26e6803-18da-4768-be30-a0a68132107c")
second_page_id = UUID("c673bd94-96b1-4a2e-8662-a4d806940b5f")

# The dataset should already be saved in database when we call `process_split`
worker.cached_dataset = CachedDataset.create(
    id=uuid4(),
    name="My dataset",
    state="complete",
    sets=json.dumps(["train", "val", "test"]),
)

# Process both pages as the "train" split
worker.process_split(
    "train",
    [
        Element.get_by_id(first_page_id),
        Element.get_by_id(second_page_id),
    ],
)
# Should have created 19 elements in total
# (2 pages + 8 + 9 text_lines — see the per-page counts below)
assert CachedElement.select().count() == 19
# Should have created two pages at root
assert CachedElement.select().where(CachedElement.parent_id.is_null()).count() == 2
# Should have created 8 text_lines under first page
assert (
    CachedElement.select().where(CachedElement.parent_id == first_page_id).count()
    == 8
)
# Should have created 9 text_lines under second page
assert (
    CachedElement.select().where(CachedElement.parent_id == second_page_id).count()
    == 9
)

# Should have created one classification
assert CachedClassification.select().count() == 1
# Check classification
classif = CachedClassification.get_by_id("ea80bc43-6e96-45a1-b6d7-70f29b5871a6")
assert classif.element.id == first_page_id
assert classif.class_name == "Hello darkness"
assert classif.confidence == 1.0
assert classif.state == "validated"
assert classif.worker_run_id is None
# Should have created two images
assert CachedImage.select().count() == 2
first_image_id = UUID("80a84b30-1ae1-4c13-95d6-7d0d8ee16c51")
second_image_id = UUID("e3c755f2-0e1c-468e-ae4c-9206f0fd267a")
# Check images
for image_id, page_name in (
(first_image_id, "g06-018m"),
(second_image_id, "c02-026"),
):
page_image = CachedImage.get_by_id(image_id)
assert page_image.width == 2479
assert page_image.height == 3542
assert (
page_image.url
== f"https://europe-gamma.iiif.teklia.com/iiif/2/public%2Fiam%2F{page_name}.png"
)
# Should have created 17 transcriptions
assert CachedTranscription.select().count() == 17
# Check transcription of first line on first page
transcription: CachedTranscription = CachedTranscription.get_by_id(
    "144974c3-2477-4101-a486-6c276b0070aa"
)
assert transcription.element.id == UUID("6411b293-dee0-4002-ac82-ffc5cdcb499d")
assert transcription.text == "When the sailing season was past , he sent Pearl"
assert transcription.confidence == 1.0
assert transcription.orientation == "horizontal-lr"
assert transcription.worker_version_id is None
assert transcription.worker_run_id is None

# Should have created a transcription entity linked to transcription of first line of first page
assert CachedEntity.select().count() == 1
entity: CachedEntity = CachedEntity.get_by_id(
    "e04b0323-0dda-4f76-b218-af40f7d40c84"
)
assert entity.name == "When the sailing season"
assert entity.type == "something something"
assert entity.metas == "{}"
assert entity.validated is False
assert entity.worker_run_id is None

# Checks on transcription entity
assert CachedTranscriptionEntity.select().count() == 1
tr_entity: CachedTranscriptionEntity = (
    CachedTranscriptionEntity.select()
    .where(
        (
            CachedTranscriptionEntity.transcription
            == "144974c3-2477-4101-a486-6c276b0070aa"
        )
        & (CachedTranscriptionEntity.entity == entity)
    )
    .get()
)
assert tr_entity.offset == 0
assert tr_entity.length == 23
assert tr_entity.confidence == 1.0
assert tr_entity.worker_run_id is None
# Should have linked the page elements to the correct dataset & split
assert CachedDatasetElement.select().count() == 2
# Both rows must point at our dataset and the "train" set
assert (
    CachedDatasetElement.select()
    .where(
        CachedDatasetElement.dataset == worker.cached_dataset,
        CachedDatasetElement.set_name == "train",
    )
    .count()
    == 2
)

# Full structure of the archive: the cache DB plus one JPEG per page image
assert sorted(tmp_path.rglob("*")) == [
    tmp_path / "db.sqlite",
    tmp_path / "images",
    tmp_path / "images" / f"{first_image_id}.jpg",
    tmp_path / "images" / f"{second_image_id}.jpg",
]