# -*- coding: utf-8 -*-

import json
from argparse import Namespace
from uuid import UUID, uuid4

from arkindex_export.models import Element
from arkindex_worker.cache import (
    CachedClassification,
    CachedDataset,
    CachedDatasetElement,
    CachedElement,
    CachedEntity,
    CachedImage,
    CachedTranscription,
    CachedTranscriptionEntity,
)
from worker_generic_training_dataset.from_sql import DatasetExtractorFromSQL


def test_process_split(tmp_path, downloaded_images):
    """Check what `process_split` writes for a two-page "train" split.

    The test runs `DatasetExtractorFromSQL.process_split` on two page
    elements and then verifies the rows present in the cache database
    (elements, classifications, images, transcriptions, entities,
    transcription entities and dataset links) as well as the files found
    under the data folder.

    Args:
        tmp_path: pytest temporary directory, used as the worker data folder.
        downloaded_images: fixture requested for its side effects only
            (presumably it stubs the image downloads; confirm in conftest).
    """
    worker = DatasetExtractorFromSQL()

    # Parse some arguments
    worker.args = Namespace(database=None)
    worker.data_folder_path = tmp_path
    worker.configure_cache()
    worker.cached_images = {}

    # Where to save the downloaded images
    worker.images_folder = tmp_path / "images"
    worker.images_folder.mkdir(parents=True)

    first_page_id = UUID("e26e6803-18da-4768-be30-a0a68132107c")
    second_page_id = UUID("c673bd94-96b1-4a2e-8662-a4d806940b5f")

    # The dataset should already be saved in database when we call `process_split`
    worker.cached_dataset = CachedDataset.create(
        id=uuid4(),
        name="My dataset",
        state="complete",
        sets=json.dumps(["train", "val", "test"]),
    )

    worker.process_split(
        "train",
        [
            Element.get_by_id(first_page_id),
            Element.get_by_id(second_page_id),
        ],
    )

    # Should have created 19 elements in total (2 pages + 8 + 9 text lines)
    assert CachedElement.select().count() == 19

    # Should have created two pages at root
    assert CachedElement.select().where(CachedElement.parent_id.is_null()).count() == 2

    # Should have created 8 text_lines under first page
    assert (
        CachedElement.select().where(CachedElement.parent_id == first_page_id).count()
        == 8
    )
    # Should have created 9 text_lines under second page
    assert (
        CachedElement.select().where(CachedElement.parent_id == second_page_id).count()
        == 9
    )

    # Should have created one classification
    assert CachedClassification.select().count() == 1
    # Check classification
    classif = CachedClassification.get_by_id("ea80bc43-6e96-45a1-b6d7-70f29b5871a6")
    assert classif.element.id == first_page_id
    assert classif.class_name == "Hello darkness"
    assert classif.confidence == 1.0
    assert classif.state == "validated"
    assert classif.worker_run_id is None

    # Should have created two images
    assert CachedImage.select().count() == 2
    first_image_id = UUID("80a84b30-1ae1-4c13-95d6-7d0d8ee16c51")
    second_image_id = UUID("e3c755f2-0e1c-468e-ae4c-9206f0fd267a")
    # Check images
    for image_id, page_name in (
        (first_image_id, "g06-018m"),
        (second_image_id, "c02-026"),
    ):
        page_image = CachedImage.get_by_id(image_id)
        assert page_image.width == 2479
        assert page_image.height == 3542
        assert (
            page_image.url
            == f"https://europe-gamma.iiif.teklia.com/iiif/2/public%2Fiam%2F{page_name}.png"
        )

    # Should have created 17 transcriptions
    assert CachedTranscription.select().count() == 17
    # Check transcription of first line on first page
    transcription: CachedTranscription = CachedTranscription.get_by_id(
        "144974c3-2477-4101-a486-6c276b0070aa"
    )
    assert transcription.element.id == UUID("6411b293-dee0-4002-ac82-ffc5cdcb499d")
    assert transcription.text == "When the sailing season was past , he sent Pearl"
    assert transcription.confidence == 1.0
    assert transcription.orientation == "horizontal-lr"
    assert transcription.worker_version_id is None
    assert transcription.worker_run_id is None

    # Should have created a transcription entity linked to transcription of first line of first page
    # Checks on entity
    assert CachedEntity.select().count() == 1
    entity: CachedEntity = CachedEntity.get_by_id(
        "e04b0323-0dda-4f76-b218-af40f7d40c84"
    )
    assert entity.name == "When the sailing season"
    assert entity.type == "something something"
    assert entity.metas == "{}"
    assert entity.validated is False
    assert entity.worker_run_id is None

    # Checks on transcription entity
    assert CachedTranscriptionEntity.select().count() == 1
    tr_entity: CachedTranscriptionEntity = (
        CachedTranscriptionEntity.select()
        .where(
            (
                CachedTranscriptionEntity.transcription
                == "144974c3-2477-4101-a486-6c276b0070aa"
            )
            & (CachedTranscriptionEntity.entity == entity)
        )
        .get()
    )
    assert tr_entity.offset == 0
    assert tr_entity.length == 23
    assert tr_entity.confidence == 1.0
    assert tr_entity.worker_run_id is None

    # Should have linked the page elements to the correct dataset & split
    assert CachedDatasetElement.select().count() == 2
    assert (
        CachedDatasetElement.select()
        .where(
            CachedDatasetElement.dataset == worker.cached_dataset,
            CachedDatasetElement.set_name == "train",
        )
        .count()
        == 2
    )

    # Full structure of the archive
    assert sorted(tmp_path.rglob("*")) == [
        tmp_path / "db.sqlite",
        tmp_path / "images",
        tmp_path / "images" / f"{first_image_id}.jpg",
        tmp_path / "images" / f"{second_image_id}.jpg",
    ]