From 6f7e9b571946e609ff8525eb95b485fe96e0b646 Mon Sep 17 00:00:00 2001 From: Yoann Schneider <yschneider@teklia.com> Date: Wed, 10 May 2023 15:01:57 +0200 Subject: [PATCH] moar tests --- tests/conftest.py | 35 +++++- tests/test_worker.py | 104 +++++++++++++---- worker_generic_training_dataset/db.py | 17 +-- worker_generic_training_dataset/utils.py | 24 ++-- worker_generic_training_dataset/worker.py | 132 ++++++++++++---------- 5 files changed, 202 insertions(+), 110 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index b0da596..98a69b3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,7 +8,7 @@ from arkindex.mock import MockApiClient from arkindex_export import open_database from arkindex_worker.worker.base import BaseWorker -DATA_PATH = Path(__file__).parent / "data" +DATA_PATH: Path = Path(__file__).parent / "data" @pytest.fixture(autouse=True) @@ -22,8 +22,6 @@ def setup_environment(responses, monkeypatch): "https://arkindex.teklia.com/api/v1/openapi/?format=openapi-json", ) responses.add_passthru(schema_url) - # To allow image download - responses.add_passthru("https://europe-gamma.iiif.teklia.com/iiif/2") # Set schema url in environment os.environ["ARKINDEX_API_SCHEMA_URL"] = schema_url @@ -35,5 +33,34 @@ def setup_environment(responses, monkeypatch): @pytest.fixture(scope="session", autouse=True) -def arkindex_db(): +def arkindex_db() -> None: open_database(DATA_PATH / "arkindex_export.sqlite") + + +@pytest.fixture +def page_1_image() -> bytes: + with (DATA_PATH / "sample_image_1.jpg").open("rb") as image: + return image.read() + + +@pytest.fixture +def page_2_image() -> bytes: + with (DATA_PATH / "sample_image_2.jpg").open("rb") as image: + return image.read() + + +@pytest.fixture +def downloaded_images(responses, page_1_image, page_2_image) -> None: + # Mock image download call + # Page 1 + responses.add( + responses.GET, + "https://europe-gamma.iiif.teklia.com/iiif/2/public%2Fiam%2Fg06-018m.png/0,0,2479,3542/full/0/default.jpg", + body=page_1_image, + ) + # Page 2 + responses.add( + responses.GET, + "https://europe-gamma.iiif.teklia.com/iiif/2/public%2Fiam%2Fc02-026.png/0,0,2479,3542/full/0/default.jpg", + body=page_2_image, + ) diff --git a/tests/test_worker.py b/tests/test_worker.py index cf638a8..9a2aeac 100644 --- a/tests/test_worker.py +++ b/tests/test_worker.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- from argparse import Namespace +from uuid import UUID from arkindex_worker.cache import ( CachedClassification, @@ -13,14 +14,14 @@ from arkindex_worker.cache import ( from worker_generic_training_dataset.worker import DatasetExtractor -def test_process_split(tmp_path): +def test_process_split(tmp_path, downloaded_images): # Parent is train folder - parent_id = "a0c4522d-2d80-4766-a01c-b9d686f41f6a" + parent_id: UUID = UUID("a0c4522d-2d80-4766-a01c-b9d686f41f6a") worker = DatasetExtractor() - # Mock important configuration steps - worker.args = Namespace(dev=False) - worker.initialize_database() + # Parse some arguments + worker.args = Namespace(dev=False, database=None) + worker.configure_cache() worker.cached_images = dict() # Where to save the downloaded images @@ -28,44 +29,99 @@ def test_process_split(tmp_path): worker.process_split("train", parent_id) + # Should have created 20 elements in total + assert CachedElement.select().count() == 20 + # Should have created two pages under root folder assert ( CachedElement.select().where(CachedElement.parent_id == parent_id).count() == 2 ) + first_page_id = UUID("e26e6803-18da-4768-be30-a0a68132107c") + 
second_page_id = UUID("c673bd94-96b1-4a2e-8662-a4d806940b5f") + # Should have created 8 text_lines under first page assert ( - CachedElement.select() - .where(CachedElement.parent_id == "e26e6803-18da-4768-be30-a0a68132107c") - .count() + CachedElement.select().where(CachedElement.parent_id == first_page_id).count() == 8 ) - - # Should have created one classification linked to first page + # Should have created 9 text_lines under second page assert ( - CachedClassification.select() - .where(CachedClassification.element == "e26e6803-18da-4768-be30-a0a68132107c") - .count() - == 1 + CachedElement.select().where(CachedElement.parent_id == second_page_id).count() + == 9 ) + # Should have created one classification + assert CachedClassification.select().count() == 1 + # Check classification + classif = CachedClassification.get_by_id("ea80bc43-6e96-45a1-b6d7-70f29b5871a6") + assert classif.element.id == first_page_id + assert classif.class_name == "Hello darkness" + assert classif.confidence == 1.0 + assert classif.state == "validated" + assert classif.worker_run_id is None + # Should have created two images assert CachedImage.select().count() == 2 + first_image_id = UUID("80a84b30-1ae1-4c13-95d6-7d0d8ee16c51") + second_image_id = UUID("e3c755f2-0e1c-468e-ae4c-9206f0fd267a") + # Check images + for image_id, page_name in ( + (first_image_id, "g06-018m"), + (second_image_id, "c02-026"), + ): + page_image = CachedImage.get_by_id(image_id) + assert page_image.width == 2479 + assert page_image.height == 3542 + assert ( + page_image.url + == f"https://europe-gamma.iiif.teklia.com/iiif/2/public%2Fiam%2F{page_name}.png" + ) + assert sorted(tmp_path.rglob("*")) == [ - tmp_path / "80a84b30-1ae1-4c13-95d6-7d0d8ee16c51.jpg", - tmp_path / "e3c755f2-0e1c-468e-ae4c-9206f0fd267a.jpg", + tmp_path / f"{first_image_id}.jpg", + tmp_path / f"{second_image_id}.jpg", ] - # Should have created a transcription linked to first line of first page - assert ( - CachedTranscription.select() - .where(CachedTranscription.element == "6411b293-dee0-4002-ac82-ffc5cdcb499d") - .count() - == 1 + # Should have created 17 transcriptions + assert CachedTranscription.select().count() == 17 + # Check transcription of first line on first page + transcription: CachedTranscription = CachedTranscription.get_by_id( + "144974c3-2477-4101-a486-6c276b0070aa" ) + assert transcription.element.id == UUID("6411b293-dee0-4002-ac82-ffc5cdcb499d") + assert transcription.text == "When the sailing season was past , he sent Pearl" + assert transcription.confidence == 1.0 + assert transcription.orientation == "horizontal-lr" + assert transcription.worker_version_id is None + assert transcription.worker_run_id is None # Should have created a transcription entity linked to transcription of first line of first page + # Checks on entity assert CachedEntity.select().count() == 1 - assert CachedTranscriptionEntity.select().count() == 1 + entity: CachedEntity = CachedEntity.get_by_id( + "e04b0323-0dda-4f76-b218-af40f7d40c84" + ) + assert entity.name == "When the sailing season" + assert entity.type == "something something" + assert entity.metas == "{}" + assert entity.validated is False + assert entity.worker_run_id is None - assert CachedEntity.get_by_id("e04b0323-0dda-4f76-b218-af40f7d40c84") + # Checks on transcription entity + assert CachedTranscriptionEntity.select().count() == 1 + tr_entity: CachedTranscriptionEntity = ( + CachedTranscriptionEntity.select() + .where( + ( + CachedTranscriptionEntity.transcription + == "144974c3-2477-4101-a486-6c276b0070aa" 
+ ) + & (CachedTranscriptionEntity.entity == entity) + ) + .get() + ) + assert tr_entity.offset == 0 + assert tr_entity.length == 23 + assert tr_entity.confidence == 1.0 + assert tr_entity.worker_run_id is None diff --git a/worker_generic_training_dataset/db.py b/worker_generic_training_dataset/db.py index bb18f7e..8ae47d6 100644 --- a/worker_generic_training_dataset/db.py +++ b/worker_generic_training_dataset/db.py @@ -1,5 +1,7 @@ # -*- coding: utf-8 -*- +from uuid import UUID + from arkindex_export import Classification from arkindex_export.models import ( Element, @@ -8,25 +10,24 @@ from arkindex_export.models import ( Transcription, TranscriptionEntity, ) -from arkindex_worker.cache import CachedElement, CachedTranscription -def retrieve_element(element_id: str): +def retrieve_element(element_id: UUID) -> Element: return Element.get_by_id(element_id) -def list_classifications(element_id: str): - return Classification.select().where(Classification.element_id == element_id) +def list_classifications(element_id: UUID): + return Classification.select().where(Classification.element == element_id) -def list_transcriptions(element: CachedElement): - return Transcription.select().where(Transcription.element_id == element.id) +def list_transcriptions(element_id: UUID): + return Transcription.select().where(Transcription.element == element_id) -def list_transcription_entities(transcription: CachedTranscription): +def list_transcription_entities(transcription_id: UUID): return ( TranscriptionEntity.select() - .where(TranscriptionEntity.transcription_id == transcription.id) + .where(TranscriptionEntity.transcription == transcription_id) .join(Entity, on=TranscriptionEntity.entity) .join(EntityType, on=Entity.type) ) diff --git a/worker_generic_training_dataset/utils.py b/worker_generic_training_dataset/utils.py index 76c2bca..34d9760 100644 --- a/worker_generic_training_dataset/utils.py +++ b/worker_generic_training_dataset/utils.py @@ -1,23 +1,21 @@ # -*- coding: utf-8 -*- -import ast +import json import logging +from logging import Logger from urllib.parse import urljoin -logger = logging.getLogger(__name__) +from arkindex_worker.image import BoundingBox, polygon_bounding_box +logger: Logger = logging.getLogger(__name__) -def bounding_box(polygon: list): - """ - Returns a 4-tuple (x, y, width, height) for the bounding box of a Polygon (list of points) - """ - all_x, all_y = zip(*polygon) - x, y = min(all_x), min(all_y) - width, height = max(all_x) - x, max(all_y) - y - return int(x), int(y), int(width), int(height) - -def build_image_url(element): - x, y, width, height = bounding_box(json.loads(element.polygon)) +def build_image_url(element) -> str: + bbox: BoundingBox = polygon_bounding_box(json.loads(element.polygon)) + x: int + y: int + width: int + height: int + x, y, width, height = bbox return urljoin( element.image.url + "/", f"{x},{y},{width},{height}/full/0/default.jpg" ) diff --git a/worker_generic_training_dataset/worker.py b/worker_generic_training_dataset/worker.py index 88a16a6..7c20093 100644 --- a/worker_generic_training_dataset/worker.py +++ b/worker_generic_training_dataset/worker.py @@ -2,7 +2,9 @@ import logging import operator import tempfile +from argparse import Namespace from pathlib import Path +from tempfile import _TemporaryFileWrapper from typing import List, Optional from uuid import UUID @@ -32,15 +34,15 @@ from worker_generic_training_dataset.db import ( ) from worker_generic_training_dataset.utils import build_image_url -logger = logging.getLogger(__name__) 
+logger: logging.Logger = logging.getLogger(__name__) BULK_BATCH_SIZE = 50 DEFAULT_TRANSCRIPTION_ORIENTATION = "horizontal-lr" class DatasetExtractor(BaseWorker): - def configure(self): - self.args = self.parser.parse_args() + def configure(self) -> None: + self.args: Namespace = self.parser.parse_args() if self.is_read_only: super().configure_for_developers() else: @@ -57,7 +59,7 @@ class DatasetExtractor(BaseWorker): self.download_latest_export() # Initialize db that will be written - self.initialize_database() + self.configure_cache() # CachedImage downloaded and created in DB self.cached_images = dict() @@ -66,7 +68,7 @@ class DatasetExtractor(BaseWorker): self.image_folder = Path(tempfile.mkdtemp(suffix="-arkindex-data")) logger.info(f"Images will be saved at `{self.image_folder}`.") - def read_training_related_information(self): + def read_training_related_information(self) -> None: """ Read from process information - train_folder_id @@ -84,23 +86,26 @@ class DatasetExtractor(BaseWorker): self.validation_folder_id = UUID(val_folder_id) test_folder_id = self.config.get("test_folder_id") - self.testing_folder_id = UUID(test_folder_id) if test_folder_id else None + self.testing_folder_id: UUID | None = ( + UUID(test_folder_id) if test_folder_id else None + ) - def initialize_database(self): + def configure_cache(self) -> None: """ Create an SQLite database compatible with base-worker cache and initialize it. """ - database_path = self.work_dir / "db.sqlite" + self.use_cache = True + self.cache_path: Path = self.args.database or self.work_dir / "db.sqlite" # Remove previous execution result if present - database_path.unlink(missing_ok=True) + self.cache_path.unlink(missing_ok=True) - init_cache_db(database_path) + init_cache_db(self.cache_path) create_version_table() create_tables() - def download_latest_export(self): + def download_latest_export(self) -> None: """ Download the latest export of the current corpus. Export must be in `"done"` state. @@ -119,7 +124,7 @@ class DatasetExtractor(BaseWorker): raise e # Find the latest that is in "done" state - exports = sorted( + exports: List[dict] = sorted( list(filter(lambda exp: exp["state"] == "done", exports)), key=operator.itemgetter("updated"), reverse=True, @@ -128,11 +133,11 @@ class DatasetExtractor(BaseWorker): len(exports) > 0 ), f"No available exports found for the corpus {self.corpus_id}." 
- # Download latest it in a tmpfile + # Download latest export try: - export_id = exports[0]["id"] + export_id: str = exports[0]["id"] logger.info(f"Downloading export ({export_id})...") - self.export = self.api_client.request( + self.export: _TemporaryFileWrapper = self.api_client.request( "DownloadExport", id=export_id, ) @@ -146,7 +151,7 @@ class DatasetExtractor(BaseWorker): def insert_classifications(self, element: CachedElement) -> None: logger.info("Listing classifications") - classifications = [ + classifications: list[CachedClassification] = [ CachedClassification( id=classification.id, element=element, @@ -168,7 +173,7 @@ class DatasetExtractor(BaseWorker): self, element: CachedElement ) -> List[CachedTranscription]: logger.info("Listing transcriptions") - transcriptions = [ + transcriptions: list[CachedTranscription] = [ CachedTranscription( id=transcription.id, element=element, @@ -180,7 +185,7 @@ class DatasetExtractor(BaseWorker): if transcription.worker_version else None, ) - for transcription in list_transcriptions(element) + for transcription in list_transcriptions(element.id) ] if transcriptions: logger.info(f"Inserting {len(transcriptions)} transcription(s)") @@ -193,9 +198,10 @@ class DatasetExtractor(BaseWorker): def insert_entities(self, transcriptions: List[CachedTranscription]) -> None: logger.info("Listing entities") - extracted_entities = [] + entities: List[CachedEntity] = [] + transcription_entities: List[CachedTranscriptionEntity] = [] for transcription in transcriptions: - for transcription_entity in list_transcription_entities(transcription): + for transcription_entity in list_transcription_entities(transcription.id): entity = CachedEntity( id=transcription_entity.entity.id, type=transcription_entity.entity.type.name, @@ -203,17 +209,15 @@ class DatasetExtractor(BaseWorker): validated=transcription_entity.entity.validated, metas=transcription_entity.entity.metas, ) - extracted_entities.append( - ( - entity, - CachedTranscriptionEntity( - id=transcription_entity.id, - transcription=transcription, - entity=entity, - offset=transcription_entity.offset, - length=transcription_entity.length, - confidence=transcription_entity.confidence, - ), + entities.append(entity) + transcription_entities.append( + CachedTranscriptionEntity( + id=transcription_entity.id, + transcription=transcription, + entity=entity, + offset=transcription_entity.offset, + length=transcription_entity.length, + confidence=transcription_entity.confidence, ) ) if entities: @@ -236,17 +240,19 @@ class DatasetExtractor(BaseWorker): batch_size=BULK_BATCH_SIZE, ) - def insert_element(self, element: Element, parent_id: Optional[str] = None): + def insert_element( + self, element: Element, parent_id: Optional[UUID] = None + ) -> None: """ - Insert the given element's children in the cache database. - Their image will also be saved to disk, if they weren't already. + Insert the given element in the cache database. + Its image will also be saved to disk, if it wasn't already. The insertion of an element includes: - its classifications - its transcriptions - its transcriptions' entities (both Entity and TranscriptionEntity) - :param element: Element to insert. All its children will be inserted as well. + :param element: Element to insert. :param parent_id: ID of the parent to use when creating the CachedElement. Do not specify for top-level elements. 
""" logger.info(f"Processing element ({element.id})") @@ -259,41 +265,45 @@ class DatasetExtractor(BaseWorker): # Insert image logger.info("Inserting image") # Store images in case some other elements use it as well - self.cached_images[element.image_id] = CachedImage.create( - id=element.image.id, - width=element.image.width, - height=element.image.height, - url=element.image.url, - ) + with cache_database.atomic(): + self.cached_images[element.image.id] = CachedImage.create( + id=element.image.id, + width=element.image.width, + height=element.image.height, + url=element.image.url, + ) # Insert element logger.info("Inserting element") - cached_element = CachedElement.create( - id=element.id, - parent_id=parent_id, - type=element.type, - image=self.cached_images[element.image.id] if element.image else None, - polygon=element.polygon, - rotation_angle=element.rotation_angle, - mirrored=element.mirrored, - worker_version_id=element.worker_version.id - if element.worker_version - else None, - confidence=element.confidence, - ) + with cache_database.atomic(): + cached_element: CachedElement = CachedElement.create( + id=element.id, + parent_id=parent_id, + type=element.type, + image=self.cached_images[element.image.id] if element.image else None, + polygon=element.polygon, + rotation_angle=element.rotation_angle, + mirrored=element.mirrored, + worker_version_id=element.worker_version.id + if element.worker_version + else None, + confidence=element.confidence, + ) # Insert classifications self.insert_classifications(cached_element) # Insert transcriptions - transcriptions = self.insert_transcriptions(cached_element) + transcriptions: List[CachedTranscription] = self.insert_transcriptions( + cached_element + ) # Insert entities self.insert_entities(transcriptions) - def process_split(self, split_name: str, split_id: UUID): + def process_split(self, split_name: str, split_id: UUID) -> None: """ - Insert all elements under the given parent folder. + Insert all elements under the given parent folder (all queries are recursive). - `page` elements are linked to this folder (via parent_id foreign key) - `page` element children are linked to their `page` parent (via parent_id foreign key) """ @@ -302,12 +312,12 @@ class DatasetExtractor(BaseWorker): ) # Fill cache # Retrieve parent and create parent - parent = retrieve_element(split_id) + parent: Element = retrieve_element(split_id) self.insert_element(parent) # First list all pages pages = list_children(split_id).join(Image).where(Element.type == "page") - nb_pages = pages.count() + nb_pages: int = pages.count() for idx, page in enumerate(pages, start=1): logger.info(f"Processing `{split_name}` page ({idx}/{nb_pages})") @@ -316,7 +326,7 @@ class DatasetExtractor(BaseWorker): # List children children = list_children(page.id) - nb_children = children.count() + nb_children: int = children.count() for child_idx, child in enumerate(children, start=1): logger.info(f"Processing child ({child_idx}/{nb_children})") # Insert child @@ -336,7 +346,7 @@ class DatasetExtractor(BaseWorker): self.process_split(split_name, split_id) # TAR + ZSTD Image folder and store as task artifact - zstd_archive_path = self.work_dir / "arkindex_data.zstd" + zstd_archive_path: Path = self.work_dir / "arkindex_data.zstd" logger.info(f"Compressing the images to {zstd_archive_path}") create_tar_zst_archive(source=self.image_folder, destination=zstd_archive_path) -- GitLab