# -*- coding: utf-8 -*-
import json
from argparse import Namespace
from uuid import UUID, uuid4

from arkindex_export.models import Element
from arkindex_worker.cache import (
    CachedClassification,
    CachedDataset,
    CachedDatasetElement,
    CachedElement,
    CachedEntity,
    CachedImage,
    CachedTranscription,
    CachedTranscriptionEntity,
)

from worker_generic_training_dataset.from_sql import DatasetExtractorFromSQL


def test_process_split(tmp_path, downloaded_images):
    """End-to-end check of `DatasetExtractorFromSQL.process_split` on a "train" split.

    Runs the extraction for two known pages, then asserts the contents of the
    resulting cache database (elements, classification, images, transcriptions,
    entities, transcription entities, dataset-element links) and the files
    produced under `tmp_path`.

    Fixtures:
        tmp_path: pytest-provided working directory; receives the cache DB and
            downloaded images.
        downloaded_images: project fixture (defined elsewhere) that makes the
            page images available — presumably pre-populating the IIIF
            downloads; TODO confirm against conftest.
    """
    worker = DatasetExtractorFromSQL()
    # Parse some arguments
    worker.args = Namespace(database=None)
    worker.data_folder_path = tmp_path
    worker.configure_cache()
    worker.cached_images = dict()

    # Where to save the downloaded images
    worker.images_folder = tmp_path / "images"
    worker.images_folder.mkdir(parents=True)

    # IDs of the two page elements extracted from the source (arkindex_export) DB
    first_page_id = UUID("e26e6803-18da-4768-be30-a0a68132107c")
    second_page_id = UUID("c673bd94-96b1-4a2e-8662-a4d806940b5f")

    # The dataset should already be saved in database when we call `process_split`
    worker.cached_dataset = CachedDataset.create(
        id=uuid4(),
        name="My dataset",
        state="complete",
        sets=json.dumps(["train", "val", "test"]),
    )

    worker.process_split(
        "train",
        [
            Element.get_by_id(first_page_id),
            Element.get_by_id(second_page_id),
        ],
    )

    # Should have created 19 elements in total
    assert CachedElement.select().count() == 19

    # Should have created two pages at root
    assert CachedElement.select().where(CachedElement.parent_id.is_null()).count() == 2

    # Should have created 8 text_lines under first page
    assert (
        CachedElement.select().where(CachedElement.parent_id == first_page_id).count()
        == 8
    )
    # Should have created 9 text_lines under second page
    assert (
        CachedElement.select().where(CachedElement.parent_id == second_page_id).count()
        == 9
    )

    # Should have created one classification
    assert CachedClassification.select().count() == 1
    # Check classification
    classif = CachedClassification.get_by_id("ea80bc43-6e96-45a1-b6d7-70f29b5871a6")
    assert classif.element.id == first_page_id
    assert classif.class_name == "Hello darkness"
    assert classif.confidence == 1.0
    assert classif.state == "validated"
    assert classif.worker_run_id is None

    # Should have created two images
    assert CachedImage.select().count() == 2
    first_image_id = UUID("80a84b30-1ae1-4c13-95d6-7d0d8ee16c51")
    second_image_id = UUID("e3c755f2-0e1c-468e-ae4c-9206f0fd267a")
    # Check images: both IAM pages share the same dimensions, only the URL differs
    for image_id, page_name in (
        (first_image_id, "g06-018m"),
        (second_image_id, "c02-026"),
    ):
        page_image = CachedImage.get_by_id(image_id)
        assert page_image.width == 2479
        assert page_image.height == 3542
        assert (
            page_image.url
            == f"https://europe-gamma.iiif.teklia.com/iiif/2/public%2Fiam%2F{page_name}.png"
        )

    # Should have created 17 transcriptions
    assert CachedTranscription.select().count() == 17
    # Check transcription of first line on first page
    transcription: CachedTranscription = CachedTranscription.get_by_id(
        "144974c3-2477-4101-a486-6c276b0070aa"
    )
    assert transcription.element.id == UUID("6411b293-dee0-4002-ac82-ffc5cdcb499d")
    assert transcription.text == "When the sailing season was past , he sent Pearl"
    assert transcription.confidence == 1.0
    assert transcription.orientation == "horizontal-lr"
    assert transcription.worker_version_id is None
    assert transcription.worker_run_id is None

    # Should have created a transcription entity linked to transcription of first line of first page
    # Checks on entity
    assert CachedEntity.select().count() == 1
    entity: CachedEntity = CachedEntity.get_by_id(
        "e04b0323-0dda-4f76-b218-af40f7d40c84"
    )
    assert entity.name == "When the sailing season"
    assert entity.type == "something something"
    assert entity.metas == "{}"
    assert entity.validated is False
    assert entity.worker_run_id is None

    # Checks on transcription entity
    assert CachedTranscriptionEntity.select().count() == 1
    tr_entity: CachedTranscriptionEntity = (
        CachedTranscriptionEntity.select()
        .where(
            (
                CachedTranscriptionEntity.transcription
                == "144974c3-2477-4101-a486-6c276b0070aa"
            )
            & (CachedTranscriptionEntity.entity == entity)
        )
        .get()
    )
    assert tr_entity.offset == 0
    assert tr_entity.length == 23
    assert tr_entity.confidence == 1.0
    assert tr_entity.worker_run_id is None

    # Should have linked the page elements to the correct dataset & split
    assert CachedDatasetElement.select().count() == 2
    assert (
        CachedDatasetElement.select()
        .where(
            CachedDatasetElement.dataset == worker.cached_dataset,
            CachedDatasetElement.set_name == "train",
        )
        .count()
        == 2
    )

    # Full structure of the archive: the cache DB plus one downloaded JPEG per page
    assert sorted(tmp_path.rglob("*")) == [
        tmp_path / "db.sqlite",
        tmp_path / "images",
        tmp_path / "images" / f"{first_image_id}.jpg",
        tmp_path / "images" / f"{second_image_id}.jpg",
    ]