From 6f7e9b571946e609ff8525eb95b485fe96e0b646 Mon Sep 17 00:00:00 2001
From: Yoann Schneider <yschneider@teklia.com>
Date: Wed, 10 May 2023 15:01:57 +0200
Subject: [PATCH] Add more tests

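Mock the IIIF image downloads in the test suite instead of passing
them through, and assert on the full cache contents (elements,
classifications, images, transcriptions, entities and transcription
entities). Rename `initialize_database` to `configure_cache` and allow
overriding the cache path, make the export database helpers take UUIDs
directly, reuse `polygon_bounding_box` from arkindex_worker instead of
a local helper, and wrap cache writes in atomic transactions.
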
---
 tests/conftest.py                         |  35 +++++-
 tests/test_worker.py                      | 104 +++++++++++++----
 worker_generic_training_dataset/db.py     |  17 +--
 worker_generic_training_dataset/utils.py  |  24 ++--
 worker_generic_training_dataset/worker.py | 132 ++++++++++++----------
 5 files changed, 202 insertions(+), 110 deletions(-)

diff --git a/tests/conftest.py b/tests/conftest.py
index b0da596..98a69b3 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -8,7 +8,7 @@ from arkindex.mock import MockApiClient
 from arkindex_export import open_database
 from arkindex_worker.worker.base import BaseWorker
 
-DATA_PATH = Path(__file__).parent / "data"
+DATA_PATH: Path = Path(__file__).parent / "data"
 
 
 @pytest.fixture(autouse=True)
@@ -22,8 +22,6 @@ def setup_environment(responses, monkeypatch):
         "https://arkindex.teklia.com/api/v1/openapi/?format=openapi-json",
     )
     responses.add_passthru(schema_url)
-    # To allow image download
-    responses.add_passthru("https://europe-gamma.iiif.teklia.com/iiif/2")
 
     # Set schema url in environment
     os.environ["ARKINDEX_API_SCHEMA_URL"] = schema_url
@@ -35,5 +33,34 @@ def setup_environment(responses, monkeypatch):
 
 
 @pytest.fixture(scope="session", autouse=True)
-def arkindex_db():
+def arkindex_db() -> None:
     open_database(DATA_PATH / "arkindex_export.sqlite")
+
+
+@pytest.fixture
+def page_1_image() -> bytes:
+    with (DATA_PATH / "sample_image_1.jpg").open("rb") as image:
+        return image.read()
+
+
+@pytest.fixture
+def page_2_image() -> bytes:
+    with (DATA_PATH / "sample_image_2.jpg").open("rb") as image:
+        return image.read()
+
+
+@pytest.fixture
+def downloaded_images(responses, page_1_image, page_2_image) -> None:
+    # Mock image download calls
+    # Page 1
+    responses.add(
+        responses.GET,
+        "https://europe-gamma.iiif.teklia.com/iiif/2/public%2Fiam%2Fg06-018m.png/0,0,2479,3542/full/0/default.jpg",
+        body=page_1_image,
+    )
+    # Page 2
+    responses.add(
+        responses.GET,
+        "https://europe-gamma.iiif.teklia.com/iiif/2/public%2Fiam%2Fc02-026.png/0,0,2479,3542/full/0/default.jpg",
+        body=page_2_image,
+    )
diff --git a/tests/test_worker.py b/tests/test_worker.py
index cf638a8..9a2aeac 100644
--- a/tests/test_worker.py
+++ b/tests/test_worker.py
@@ -1,6 +1,7 @@
 # -*- coding: utf-8 -*-
 
 from argparse import Namespace
+from uuid import UUID
 
 from arkindex_worker.cache import (
     CachedClassification,
@@ -13,14 +14,14 @@ from arkindex_worker.cache import (
 from worker_generic_training_dataset.worker import DatasetExtractor
 
 
-def test_process_split(tmp_path):
+def test_process_split(tmp_path, downloaded_images):
     # Parent is train folder
-    parent_id = "a0c4522d-2d80-4766-a01c-b9d686f41f6a"
+    parent_id: UUID = UUID("a0c4522d-2d80-4766-a01c-b9d686f41f6a")
 
     worker = DatasetExtractor()
-    # Mock important configuration steps
-    worker.args = Namespace(dev=False)
-    worker.initialize_database()
+    # Set the CLI arguments directly instead of parsing them
+    worker.args = Namespace(dev=False, database=None)
+    worker.configure_cache()
     worker.cached_images = dict()
 
     # Where to save the downloaded images
@@ -28,44 +29,99 @@ def test_process_split(tmp_path):
 
     worker.process_split("train", parent_id)
 
+    # Should have created 20 elements in total
+    assert CachedElement.select().count() == 20
+
     # Should have created two pages under root folder
     assert (
         CachedElement.select().where(CachedElement.parent_id == parent_id).count() == 2
     )
 
+    first_page_id = UUID("e26e6803-18da-4768-be30-a0a68132107c")
+    second_page_id = UUID("c673bd94-96b1-4a2e-8662-a4d806940b5f")
+
     # Should have created 8 text_lines under first page
     assert (
-        CachedElement.select()
-        .where(CachedElement.parent_id == "e26e6803-18da-4768-be30-a0a68132107c")
-        .count()
+        CachedElement.select().where(CachedElement.parent_id == first_page_id).count()
         == 8
     )
-
-    # Should have created one classification linked to first page
+    # Should have created 9 text_lines under second page
     assert (
-        CachedClassification.select()
-        .where(CachedClassification.element == "e26e6803-18da-4768-be30-a0a68132107c")
-        .count()
-        == 1
+        CachedElement.select().where(CachedElement.parent_id == second_page_id).count()
+        == 9
     )
 
+    # Should have created one classification
+    assert CachedClassification.select().count() == 1
+    # Check classification
+    classif = CachedClassification.get_by_id("ea80bc43-6e96-45a1-b6d7-70f29b5871a6")
+    assert classif.element.id == first_page_id
+    assert classif.class_name == "Hello darkness"
+    assert classif.confidence == 1.0
+    assert classif.state == "validated"
+    assert classif.worker_run_id is None
+
     # Should have created two images
     assert CachedImage.select().count() == 2
+    first_image_id = UUID("80a84b30-1ae1-4c13-95d6-7d0d8ee16c51")
+    second_image_id = UUID("e3c755f2-0e1c-468e-ae4c-9206f0fd267a")
+    # Check images
+    for image_id, page_name in (
+        (first_image_id, "g06-018m"),
+        (second_image_id, "c02-026"),
+    ):
+        page_image = CachedImage.get_by_id(image_id)
+        assert page_image.width == 2479
+        assert page_image.height == 3542
+        assert (
+            page_image.url
+            == f"https://europe-gamma.iiif.teklia.com/iiif/2/public%2Fiam%2F{page_name}.png"
+        )
+
     assert sorted(tmp_path.rglob("*")) == [
-        tmp_path / "80a84b30-1ae1-4c13-95d6-7d0d8ee16c51.jpg",
-        tmp_path / "e3c755f2-0e1c-468e-ae4c-9206f0fd267a.jpg",
+        tmp_path / f"{first_image_id}.jpg",
+        tmp_path / f"{second_image_id}.jpg",
     ]
 
-    # Should have created a transcription linked to first line of first page
-    assert (
-        CachedTranscription.select()
-        .where(CachedTranscription.element == "6411b293-dee0-4002-ac82-ffc5cdcb499d")
-        .count()
-        == 1
+    # Should have created 17 transcriptions
+    assert CachedTranscription.select().count() == 17
+    # Check transcription of first line on first page
+    transcription: CachedTranscription = CachedTranscription.get_by_id(
+        "144974c3-2477-4101-a486-6c276b0070aa"
     )
+    assert transcription.element.id == UUID("6411b293-dee0-4002-ac82-ffc5cdcb499d")
+    assert transcription.text == "When the sailing season was past , he sent Pearl"
+    assert transcription.confidence == 1.0
+    assert transcription.orientation == "horizontal-lr"
+    assert transcription.worker_version_id is None
+    assert transcription.worker_run_id is None
 
     # Should have created a transcription entity linked to transcription of first line of first page
+    # Checks on entity
     assert CachedEntity.select().count() == 1
-    assert CachedTranscriptionEntity.select().count() == 1
+    entity: CachedEntity = CachedEntity.get_by_id(
+        "e04b0323-0dda-4f76-b218-af40f7d40c84"
+    )
+    assert entity.name == "When the sailing season"
+    assert entity.type == "something something"
+    assert entity.metas == "{}"
+    assert entity.validated is False
+    assert entity.worker_run_id is None
 
-    assert CachedEntity.get_by_id("e04b0323-0dda-4f76-b218-af40f7d40c84")
+    # Checks on transcription entity
+    assert CachedTranscriptionEntity.select().count() == 1
+    tr_entity: CachedTranscriptionEntity = (
+        CachedTranscriptionEntity.select()
+        .where(
+            (
+                CachedTranscriptionEntity.transcription
+                == "144974c3-2477-4101-a486-6c276b0070aa"
+            )
+            & (CachedTranscriptionEntity.entity == entity)
+        )
+        .get()
+    )
+    assert tr_entity.offset == 0
+    assert tr_entity.length == 23
+    assert tr_entity.confidence == 1.0
+    assert tr_entity.worker_run_id is None
diff --git a/worker_generic_training_dataset/db.py b/worker_generic_training_dataset/db.py
index bb18f7e..8ae47d6 100644
--- a/worker_generic_training_dataset/db.py
+++ b/worker_generic_training_dataset/db.py
@@ -1,5 +1,7 @@
 # -*- coding: utf-8 -*-
 
+from uuid import UUID
+
 from arkindex_export import Classification
 from arkindex_export.models import (
     Element,
@@ -8,25 +10,24 @@ from arkindex_export.models import (
     Transcription,
     TranscriptionEntity,
 )
-from arkindex_worker.cache import CachedElement, CachedTranscription
 
 
-def retrieve_element(element_id: str):
+def retrieve_element(element_id: UUID) -> Element:
     return Element.get_by_id(element_id)
 
 
-def list_classifications(element_id: str):
-    return Classification.select().where(Classification.element_id == element_id)
+def list_classifications(element_id: UUID):
+    return Classification.select().where(Classification.element == element_id)
 
 
-def list_transcriptions(element: CachedElement):
-    return Transcription.select().where(Transcription.element_id == element.id)
+def list_transcriptions(element_id: UUID):
+    return Transcription.select().where(Transcription.element == element_id)
 
 
-def list_transcription_entities(transcription: CachedTranscription):
+def list_transcription_entities(transcription_id: UUID):
     return (
         TranscriptionEntity.select()
-        .where(TranscriptionEntity.transcription_id == transcription.id)
+        .where(TranscriptionEntity.transcription == transcription_id)
         .join(Entity, on=TranscriptionEntity.entity)
         .join(EntityType, on=Entity.type)
     )
diff --git a/worker_generic_training_dataset/utils.py b/worker_generic_training_dataset/utils.py
index 76c2bca..34d9760 100644
--- a/worker_generic_training_dataset/utils.py
+++ b/worker_generic_training_dataset/utils.py
@@ -1,23 +1,21 @@
 # -*- coding: utf-8 -*-
-import ast
+import json
 import logging
+from logging import Logger
 from urllib.parse import urljoin
 
-logger = logging.getLogger(__name__)
+from arkindex_worker.image import BoundingBox, polygon_bounding_box
 
+logger: Logger = logging.getLogger(__name__)
 
-def bounding_box(polygon: list):
-    """
-    Returns a 4-tuple (x, y, width, height) for the bounding box of a Polygon (list of points)
-    """
-    all_x, all_y = zip(*polygon)
-    x, y = min(all_x), min(all_y)
-    width, height = max(all_x) - x, max(all_y) - y
-    return int(x), int(y), int(width), int(height)
 
-
-def build_image_url(element):
-    x, y, width, height = bounding_box(json.loads(element.polygon))
+def build_image_url(element) -> str:
+    bbox: BoundingBox = polygon_bounding_box(json.loads(element.polygon))
+    x: int
+    y: int
+    width: int
+    height: int
+    x, y, width, height = bbox
     return urljoin(
         element.image.url + "/", f"{x},{y},{width},{height}/full/0/default.jpg"
     )
diff --git a/worker_generic_training_dataset/worker.py b/worker_generic_training_dataset/worker.py
index 88a16a6..7c20093 100644
--- a/worker_generic_training_dataset/worker.py
+++ b/worker_generic_training_dataset/worker.py
@@ -2,7 +2,9 @@
 import logging
 import operator
 import tempfile
+from argparse import Namespace
 from pathlib import Path
+from tempfile import _TemporaryFileWrapper
 from typing import List, Optional
 from uuid import UUID
 
@@ -32,15 +34,15 @@ from worker_generic_training_dataset.db import (
 )
 from worker_generic_training_dataset.utils import build_image_url
 
-logger = logging.getLogger(__name__)
+logger: logging.Logger = logging.getLogger(__name__)
 
 BULK_BATCH_SIZE = 50
 DEFAULT_TRANSCRIPTION_ORIENTATION = "horizontal-lr"
 
 
 class DatasetExtractor(BaseWorker):
-    def configure(self):
-        self.args = self.parser.parse_args()
+    def configure(self) -> None:
+        self.args: Namespace = self.parser.parse_args()
         if self.is_read_only:
             super().configure_for_developers()
         else:
@@ -57,7 +59,7 @@ class DatasetExtractor(BaseWorker):
         self.download_latest_export()
 
         # Initialize db that will be written
-        self.initialize_database()
+        self.configure_cache()
 
         # CachedImage downloaded and created in DB
         self.cached_images = dict()
@@ -66,7 +68,7 @@ class DatasetExtractor(BaseWorker):
         self.image_folder = Path(tempfile.mkdtemp(suffix="-arkindex-data"))
         logger.info(f"Images will be saved at `{self.image_folder}`.")
 
-    def read_training_related_information(self):
+    def read_training_related_information(self) -> None:
         """
         Read from process information
             - train_folder_id
@@ -84,23 +86,26 @@ class DatasetExtractor(BaseWorker):
         self.validation_folder_id = UUID(val_folder_id)
 
         test_folder_id = self.config.get("test_folder_id")
-        self.testing_folder_id = UUID(test_folder_id) if test_folder_id else None
+        self.testing_folder_id: Optional[UUID] = (
+            UUID(test_folder_id) if test_folder_id else None
+        )
 
-    def initialize_database(self):
+    def configure_cache(self) -> None:
         """
         Create an SQLite database compatible with base-worker cache and initialize it.
         """
-        database_path = self.work_dir / "db.sqlite"
+        self.use_cache = True
+        self.cache_path: Path = self.args.database or self.work_dir / "db.sqlite"
         # Remove previous execution result if present
-        database_path.unlink(missing_ok=True)
+        self.cache_path.unlink(missing_ok=True)
 
-        init_cache_db(database_path)
+        init_cache_db(self.cache_path)
 
         create_version_table()
 
         create_tables()
 
-    def download_latest_export(self):
+    def download_latest_export(self) -> None:
         """
         Download the latest export of the current corpus.
         Export must be in `"done"` state.
@@ -119,7 +124,7 @@ class DatasetExtractor(BaseWorker):
             raise e
 
         # Find the latest that is in "done" state
-        exports = sorted(
+        exports: List[dict] = sorted(
             list(filter(lambda exp: exp["state"] == "done", exports)),
             key=operator.itemgetter("updated"),
             reverse=True,
@@ -128,11 +133,11 @@ class DatasetExtractor(BaseWorker):
             len(exports) > 0
         ), f"No available exports found for the corpus {self.corpus_id}."
 
-        # Download latest it in a tmpfile
+        # Download latest export
         try:
-            export_id = exports[0]["id"]
+            export_id: str = exports[0]["id"]
             logger.info(f"Downloading export ({export_id})...")
-            self.export = self.api_client.request(
+            self.export: _TemporaryFileWrapper = self.api_client.request(
                 "DownloadExport",
                 id=export_id,
             )
@@ -146,7 +151,7 @@ class DatasetExtractor(BaseWorker):
 
     def insert_classifications(self, element: CachedElement) -> None:
         logger.info("Listing classifications")
-        classifications = [
+        classifications: List[CachedClassification] = [
             CachedClassification(
                 id=classification.id,
                 element=element,
@@ -168,7 +173,7 @@ class DatasetExtractor(BaseWorker):
         self, element: CachedElement
     ) -> List[CachedTranscription]:
         logger.info("Listing transcriptions")
-        transcriptions = [
+        transcriptions: List[CachedTranscription] = [
             CachedTranscription(
                 id=transcription.id,
                 element=element,
@@ -180,7 +185,7 @@ class DatasetExtractor(BaseWorker):
                 if transcription.worker_version
                 else None,
             )
-            for transcription in list_transcriptions(element)
+            for transcription in list_transcriptions(element.id)
         ]
         if transcriptions:
             logger.info(f"Inserting {len(transcriptions)} transcription(s)")
@@ -193,9 +198,10 @@ class DatasetExtractor(BaseWorker):
 
     def insert_entities(self, transcriptions: List[CachedTranscription]) -> None:
         logger.info("Listing entities")
-        extracted_entities = []
+        entities: List[CachedEntity] = []
+        transcription_entities: List[CachedTranscriptionEntity] = []
         for transcription in transcriptions:
-            for transcription_entity in list_transcription_entities(transcription):
+            for transcription_entity in list_transcription_entities(transcription.id):
                 entity = CachedEntity(
                     id=transcription_entity.entity.id,
                     type=transcription_entity.entity.type.name,
@@ -203,17 +209,15 @@ class DatasetExtractor(BaseWorker):
                     validated=transcription_entity.entity.validated,
                     metas=transcription_entity.entity.metas,
                 )
-                extracted_entities.append(
-                    (
-                        entity,
-                        CachedTranscriptionEntity(
-                            id=transcription_entity.id,
-                            transcription=transcription,
-                            entity=entity,
-                            offset=transcription_entity.offset,
-                            length=transcription_entity.length,
-                            confidence=transcription_entity.confidence,
-                        ),
+                entities.append(entity)
+                transcription_entities.append(
+                    CachedTranscriptionEntity(
+                        id=transcription_entity.id,
+                        transcription=transcription,
+                        entity=entity,
+                        offset=transcription_entity.offset,
+                        length=transcription_entity.length,
+                        confidence=transcription_entity.confidence,
                     )
                 )
         if entities:
@@ -236,17 +240,19 @@ class DatasetExtractor(BaseWorker):
                     batch_size=BULK_BATCH_SIZE,
                 )
 
-    def insert_element(self, element: Element, parent_id: Optional[str] = None):
+    def insert_element(
+        self, element: Element, parent_id: Optional[UUID] = None
+    ) -> None:
         """
-        Insert the given element's children in the cache database.
-        Their image will also be saved to disk, if they weren't already.
+        Insert the given element in the cache database.
+        Its image will also be saved to disk, if it wasn't already.
 
         The insertion of an element includes:
         - its classifications
         - its transcriptions
         - its transcriptions' entities (both Entity and TranscriptionEntity)
 
-        :param element: Element to insert. All its children will be inserted as well.
+        :param element: Element to insert.
         :param parent_id: ID of the parent to use when creating the CachedElement. Do not specify for top-level elements.
         """
         logger.info(f"Processing element ({element.id})")
@@ -259,41 +265,45 @@ class DatasetExtractor(BaseWorker):
             # Insert image
             logger.info("Inserting image")
             # Store images in case some other elements use it as well
-            self.cached_images[element.image_id] = CachedImage.create(
-                id=element.image.id,
-                width=element.image.width,
-                height=element.image.height,
-                url=element.image.url,
-            )
+            with cache_database.atomic():
+                self.cached_images[element.image.id] = CachedImage.create(
+                    id=element.image.id,
+                    width=element.image.width,
+                    height=element.image.height,
+                    url=element.image.url,
+                )
 
         # Insert element
         logger.info("Inserting element")
-        cached_element = CachedElement.create(
-            id=element.id,
-            parent_id=parent_id,
-            type=element.type,
-            image=self.cached_images[element.image.id] if element.image else None,
-            polygon=element.polygon,
-            rotation_angle=element.rotation_angle,
-            mirrored=element.mirrored,
-            worker_version_id=element.worker_version.id
-            if element.worker_version
-            else None,
-            confidence=element.confidence,
-        )
+        with cache_database.atomic():
+            cached_element: CachedElement = CachedElement.create(
+                id=element.id,
+                parent_id=parent_id,
+                type=element.type,
+                image=self.cached_images[element.image.id] if element.image else None,
+                polygon=element.polygon,
+                rotation_angle=element.rotation_angle,
+                mirrored=element.mirrored,
+                worker_version_id=element.worker_version.id
+                if element.worker_version
+                else None,
+                confidence=element.confidence,
+            )
 
         # Insert classifications
         self.insert_classifications(cached_element)
 
         # Insert transcriptions
-        transcriptions = self.insert_transcriptions(cached_element)
+        transcriptions: List[CachedTranscription] = self.insert_transcriptions(
+            cached_element
+        )
 
         # Insert entities
         self.insert_entities(transcriptions)
 
-    def process_split(self, split_name: str, split_id: UUID):
+    def process_split(self, split_name: str, split_id: UUID) -> None:
         """
-        Insert all elements under the given parent folder.
+        Insert all elements under the given parent folder (all queries are recursive).
         - `page` elements are linked to this folder (via parent_id foreign key)
         - `page` element children are linked to their `page` parent (via parent_id foreign key)
         """
@@ -302,12 +312,12 @@ class DatasetExtractor(BaseWorker):
         )
         # Fill cache
         # Retrieve parent and create parent
-        parent = retrieve_element(split_id)
+        parent: Element = retrieve_element(split_id)
         self.insert_element(parent)
 
         # First list all pages
         pages = list_children(split_id).join(Image).where(Element.type == "page")
-        nb_pages = pages.count()
+        nb_pages: int = pages.count()
         for idx, page in enumerate(pages, start=1):
             logger.info(f"Processing `{split_name}` page ({idx}/{nb_pages})")
 
@@ -316,7 +326,7 @@ class DatasetExtractor(BaseWorker):
 
             # List children
             children = list_children(page.id)
-            nb_children = children.count()
+            nb_children: int = children.count()
             for child_idx, child in enumerate(children, start=1):
                 logger.info(f"Processing child ({child_idx}/{nb_children})")
                 # Insert child
@@ -336,7 +346,7 @@ class DatasetExtractor(BaseWorker):
             self.process_split(split_name, split_id)
 
         # TAR + ZSTD Image folder and store as task artifact
-        zstd_archive_path = self.work_dir / "arkindex_data.zstd"
+        zstd_archive_path: Path = self.work_dir / "arkindex_data.zstd"
         logger.info(f"Compressing the images to {zstd_archive_path}")
         create_tar_zst_archive(source=self.image_folder, destination=zstd_archive_path)
 
-- 
GitLab