Skip to content
Snippets Groups Projects

Implement worker

Merged Yoann Schneider requested to merge implem into main
1 file
+ 0
8
Compare changes
  • Side-by-side
  • Inline
+ 122
8
# -*- coding: utf-8 -*-
import importlib
from argparse import Namespace
from uuid import UUID
def test_dummy():
    """Trivial placeholder test: asserts a constant truth and nothing else."""
    assert True
from arkindex_worker.cache import (
CachedClassification,
CachedElement,
CachedEntity,
CachedImage,
CachedTranscription,
CachedTranscriptionEntity,
)
from worker_generic_training_dataset.worker import DatasetExtractor
def test_import():
    """Import our newly created module, through importlib to avoid parsing issues"""
    worker = importlib.import_module("worker_generic_training_dataset.worker")
    # NOTE(review): the rest of this file imports `DatasetExtractor` from this
    # same module; confirm that a `Demo` class with `process_element` is still
    # exported before relying on these two checks.
    assert hasattr(worker, "Demo")
    assert hasattr(worker.Demo, "process_element")
def test_process_split(tmp_path, downloaded_images):
    """Run `DatasetExtractor.process_split` on the train folder and check the cache.

    The worker is configured with an empty image registry and `tmp_path` as its
    image folder, then asked to process the "train" split. The test then checks
    the row counts and a sample of field values for every cached model
    (elements, classifications, images, transcriptions, entities and
    transcription entities), plus the image files written to `tmp_path`.

    :param tmp_path: pytest temporary directory, used as the image folder.
    :param downloaded_images: project fixture; not referenced in the body,
        requested only for whatever setup it performs.
    """
    # Parent is train folder
    parent_id: UUID = UUID("a0c4522d-2d80-4766-a01c-b9d686f41f6a")

    worker = DatasetExtractor()
    # Parse some arguments
    worker.args = Namespace(database=None)
    worker.configure_cache()
    worker.cached_images = {}

    # Where to save the downloaded images
    worker.image_folder = tmp_path

    worker.process_split("train", parent_id)

    # --- Elements ---------------------------------------------------------
    # Should have created 20 elements in total
    assert CachedElement.select().count() == 20

    # Should have created two pages under root folder
    assert (
        CachedElement.select().where(CachedElement.parent_id == parent_id).count() == 2
    )

    first_page_id = UUID("e26e6803-18da-4768-be30-a0a68132107c")
    second_page_id = UUID("c673bd94-96b1-4a2e-8662-a4d806940b5f")

    # Should have created 8 text_lines under first page
    assert (
        CachedElement.select().where(CachedElement.parent_id == first_page_id).count()
        == 8
    )
    # Should have created 9 text_lines under second page
    assert (
        CachedElement.select().where(CachedElement.parent_id == second_page_id).count()
        == 9
    )

    # --- Classifications --------------------------------------------------
    # Should have created one classification
    assert CachedClassification.select().count() == 1
    # Check classification
    classif = CachedClassification.get_by_id("ea80bc43-6e96-45a1-b6d7-70f29b5871a6")
    assert classif.element.id == first_page_id
    assert classif.class_name == "Hello darkness"
    assert classif.confidence == 1.0
    assert classif.state == "validated"
    assert classif.worker_run_id is None

    # --- Images -----------------------------------------------------------
    # Should have created two images
    assert CachedImage.select().count() == 2
    first_image_id = UUID("80a84b30-1ae1-4c13-95d6-7d0d8ee16c51")
    second_image_id = UUID("e3c755f2-0e1c-468e-ae4c-9206f0fd267a")
    # Check images
    for image_id, page_name in (
        (first_image_id, "g06-018m"),
        (second_image_id, "c02-026"),
    ):
        page_image = CachedImage.get_by_id(image_id)
        assert page_image.width == 2479
        assert page_image.height == 3542
        assert (
            page_image.url
            == f"https://europe-gamma.iiif.teklia.com/iiif/2/public%2Fiam%2F{page_name}.png"
        )

    # Exactly one file per image, named after the image id, and nothing else
    assert sorted(tmp_path.rglob("*")) == [
        tmp_path / f"{first_image_id}.jpg",
        tmp_path / f"{second_image_id}.jpg",
    ]

    # --- Transcriptions ---------------------------------------------------
    # Should have created 17 transcriptions
    assert CachedTranscription.select().count() == 17
    # Check transcription of first line on first page
    first_line_transcription_id = "144974c3-2477-4101-a486-6c276b0070aa"
    transcription: CachedTranscription = CachedTranscription.get_by_id(
        first_line_transcription_id
    )
    assert transcription.element.id == UUID("6411b293-dee0-4002-ac82-ffc5cdcb499d")
    assert transcription.text == "When the sailing season was past , he sent Pearl"
    assert transcription.confidence == 1.0
    assert transcription.orientation == "horizontal-lr"
    assert transcription.worker_version_id is None
    assert transcription.worker_run_id is None

    # --- Entities ---------------------------------------------------------
    # Should have created a transcription entity linked to transcription of first line of first page
    # Checks on entity
    assert CachedEntity.select().count() == 1
    entity: CachedEntity = CachedEntity.get_by_id(
        "e04b0323-0dda-4f76-b218-af40f7d40c84"
    )
    assert entity.name == "When the sailing season"
    assert entity.type == "something something"
    assert entity.metas == "{}"
    assert entity.validated is False
    assert entity.worker_run_id is None

    # Checks on transcription entity
    assert CachedTranscriptionEntity.select().count() == 1
    tr_entity: CachedTranscriptionEntity = (
        CachedTranscriptionEntity.select()
        .where(
            (CachedTranscriptionEntity.transcription == first_line_transcription_id)
            & (CachedTranscriptionEntity.entity == entity)
        )
        .get()
    )
    assert tr_entity.offset == 0
    assert tr_entity.length == 23
    assert tr_entity.confidence == 1.0
    assert tr_entity.worker_run_id is None
Loading