diff --git a/.arkindex.yml b/.arkindex.yml
index 1f435a4bbdf7f6476795a1af19d462a5ac1b4331..a14027b7d3609537a193062ed6dcdc2096534ade 100644
--- a/.arkindex.yml
+++ b/.arkindex.yml
@@ -9,3 +9,16 @@ workers:
     type: data-extract
     docker:
       build: Dockerfile
+    user_configuration:
+      train_folder_id:
+        type: string
+        title: ID of the training folder on Arkindex
+        required: true
+      validation_folder_id:
+        type: string
+        title: ID of the validation folder on Arkindex
+        required: true
+      test_folder_id:
+        type: string
+        title: ID of the testing folder on Arkindex
+        required: true
diff --git a/README.md b/README.md
index af907a18f989a703b0039b869f4abd64c2a7f243..e477148e65cbbfef2be64094cde549cc501bfd83 100644
--- a/README.md
+++ b/README.md
@@ -30,3 +30,5 @@
 tox
 ```
 To recreate tox virtual environment (e.g. a dependencies update), you may run `tox -r`
+
+Tests use an export from [`IAM | GT | DLA`](https://preprod.arkindex.teklia.com/browse/53919f29-f79c-4c3e-b088-0e955950879f?top_level=true&folder=true).
diff --git a/requirements.txt b/requirements.txt
index 5ff1be6e274e323aa1ab7d829d7653efed63ff8c..918e5dcd3ef8e9b1acc437a056e2a7ae70eabf64 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1 +1,2 @@
-arkindex-base-worker==0.3.2
+arkindex-base-worker==0.3.3-rc3
+arkindex-export==0.1.2
diff --git a/tests/conftest.py b/tests/conftest.py
index 14dfec177864ca1da9e6470d5bcb84f1a51e530d..98a69b3ce3c72cf6e3adc07319ae599225ea06d7 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,11 +1,15 @@
 # -*- coding: utf-8 -*-
 import os
+from pathlib import Path
 
 import pytest
 from arkindex.mock import MockApiClient
+from arkindex_export import open_database
 from arkindex_worker.worker.base import BaseWorker
 
+DATA_PATH: Path = Path(__file__).parent / "data"
+
 
 @pytest.fixture(autouse=True)
 def setup_environment(responses, monkeypatch):
@@ -26,3 +30,37 @@ def setup_environment(responses, monkeypatch):
 
     # Setup a mock api client instead of using a real one
     monkeypatch.setattr(BaseWorker, "setup_api_client", lambda _: MockApiClient())
+
+
+@pytest.fixture(scope="session", autouse=True)
+def arkindex_db() -> None:
+    open_database(DATA_PATH / "arkindex_export.sqlite")
+
+
+@pytest.fixture
+def page_1_image() -> bytes:
+    with (DATA_PATH / "sample_image_1.jpg").open("rb") as image:
+        return image.read()
+
+
+@pytest.fixture
+def page_2_image() -> bytes:
+    with (DATA_PATH / "sample_image_2.jpg").open("rb") as image:
+        return image.read()
+
+
+@pytest.fixture
+def downloaded_images(responses, page_1_image, page_2_image) -> None:
+    # Mock image download calls
+    # Page 1
+    responses.add(
+        responses.GET,
+        "https://europe-gamma.iiif.teklia.com/iiif/2/public%2Fiam%2Fg06-018m.png/0,0,2479,3542/full/0/default.jpg",
+        body=page_1_image,
+    )
+    # Page 2
+    responses.add(
+        responses.GET,
+        "https://europe-gamma.iiif.teklia.com/iiif/2/public%2Fiam%2Fc02-026.png/0,0,2479,3542/full/0/default.jpg",
+        body=page_2_image,
+    )
diff --git a/tests/data/arkindex_export.sqlite b/tests/data/arkindex_export.sqlite
new file mode 100644
index 0000000000000000000000000000000000000000..bfefcefb2c4db70f61787d85ecea51e62c4ad8cf
Binary files /dev/null and b/tests/data/arkindex_export.sqlite differ
diff --git a/tests/data/sample_image_1.jpg b/tests/data/sample_image_1.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..7645d932552430fb933ec7bc567137e12a4b5539
Binary files /dev/null and b/tests/data/sample_image_1.jpg differ
diff --git a/tests/data/sample_image_2.jpg b/tests/data/sample_image_2.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..16bad06f5f0bf45ea04fb8f3064d44f71f5c1f39
Binary files /dev/null and b/tests/data/sample_image_2.jpg differ
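The session-scoped `arkindex_db` fixture opens the bundled export once for the whole run, so tests (and the `db.py` helpers further down) can query the `arkindex_export` models directly, while `downloaded_images` keeps the suite offline by serving the two sample JPEGs for the exact IIIF crops the worker requests. A minimal sketch of an additional test relying on these fixtures — the element ID is the train folder already used in `tests/test_worker.py`:

```python
from arkindex_export import Element


def test_export_contains_train_folder():
    # The autouse arkindex_db fixture has already opened tests/data/arkindex_export.sqlite
    assert Element.get_or_none(Element.id == "a0c4522d-2d80-4766-a01c-b9d686f41f6a")
```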
diff --git a/tests/test_worker.py b/tests/test_worker.py
index bb38787b81cbba08837e9c16035209e7c37fdcdd..b7fbbdb694658d8c661af36c86e574516490fa7e 100644
--- a/tests/test_worker.py
+++ b/tests/test_worker.py
@@ -1,13 +1,127 @@
 # -*- coding: utf-8 -*-
-import importlib
+from argparse import Namespace
+from uuid import UUID
 
-
-def test_dummy():
-    assert True
+from arkindex_worker.cache import (
+    CachedClassification,
+    CachedElement,
+    CachedEntity,
+    CachedImage,
+    CachedTranscription,
+    CachedTranscriptionEntity,
+)
+from worker_generic_training_dataset.worker import DatasetExtractor
 
 
-def test_import():
-    """Import our newly created module, through importlib to avoid parsing issues"""
-    worker = importlib.import_module("worker_generic_training_dataset.worker")
-    assert hasattr(worker, "Demo")
-    assert hasattr(worker.Demo, "process_element")
+def test_process_split(tmp_path, downloaded_images):
+    # Parent is train folder
+    parent_id: UUID = UUID("a0c4522d-2d80-4766-a01c-b9d686f41f6a")
+
+    worker = DatasetExtractor()
+    # Parse some arguments
+    worker.args = Namespace(database=None)
+    worker.configure_cache()
+    worker.cached_images = dict()
+
+    # Where to save the downloaded images
+    worker.image_folder = tmp_path
+
+    worker.process_split("train", parent_id)
+
+    # Should have created 20 elements in total
+    assert CachedElement.select().count() == 20
+
+    # Should have created two pages under root folder
+    assert (
+        CachedElement.select().where(CachedElement.parent_id == parent_id).count() == 2
+    )
+
+    first_page_id = UUID("e26e6803-18da-4768-be30-a0a68132107c")
+    second_page_id = UUID("c673bd94-96b1-4a2e-8662-a4d806940b5f")
+
+    # Should have created 8 text_lines under first page
+    assert (
+        CachedElement.select().where(CachedElement.parent_id == first_page_id).count()
+        == 8
+    )
+    # Should have created 9 text_lines under second page
+    assert (
+        CachedElement.select().where(CachedElement.parent_id == second_page_id).count()
+        == 9
+    )
+
+    # Should have created one classification
+    assert CachedClassification.select().count() == 1
+    # Check classification
+    classif = CachedClassification.get_by_id("ea80bc43-6e96-45a1-b6d7-70f29b5871a6")
+    assert classif.element.id == first_page_id
+    assert classif.class_name == "Hello darkness"
+    assert classif.confidence == 1.0
+    assert classif.state == "validated"
+    assert classif.worker_run_id is None
+
+    # Should have created two images
+    assert CachedImage.select().count() == 2
+    first_image_id = UUID("80a84b30-1ae1-4c13-95d6-7d0d8ee16c51")
+    second_image_id = UUID("e3c755f2-0e1c-468e-ae4c-9206f0fd267a")
+    # Check images
+    for image_id, page_name in (
+        (first_image_id, "g06-018m"),
+        (second_image_id, "c02-026"),
+    ):
+        page_image = CachedImage.get_by_id(image_id)
+        assert page_image.width == 2479
+        assert page_image.height == 3542
+        assert (
+            page_image.url
+            == f"https://europe-gamma.iiif.teklia.com/iiif/2/public%2Fiam%2F{page_name}.png"
+        )
+
+    assert sorted(tmp_path.rglob("*")) == [
+        tmp_path / f"{first_image_id}.jpg",
+        tmp_path / f"{second_image_id}.jpg",
+    ]
+
+    # Should have created 17 transcriptions
+    assert CachedTranscription.select().count() == 17
+    # Check transcription of first line on first page
+    transcription: CachedTranscription = CachedTranscription.get_by_id(
+        "144974c3-2477-4101-a486-6c276b0070aa"
+    )
+    assert transcription.element.id == UUID("6411b293-dee0-4002-ac82-ffc5cdcb499d")
+    assert transcription.text == "When the sailing season was past , he sent Pearl"
+    assert transcription.confidence == 1.0
+    assert transcription.orientation == "horizontal-lr"
+    assert transcription.worker_version_id is None
+    assert transcription.worker_run_id is None
+
+    # Should have created a transcription entity linked to transcription of first line of first page
+    # Checks on entity
+    assert CachedEntity.select().count() == 1
+    entity: CachedEntity = CachedEntity.get_by_id(
+        "e04b0323-0dda-4f76-b218-af40f7d40c84"
+    )
+    assert entity.name == "When the sailing season"
+    assert entity.type == "something something"
+    assert entity.metas == "{}"
+    assert entity.validated is False
+    assert entity.worker_run_id is None
+
+    # Checks on transcription entity
+    assert CachedTranscriptionEntity.select().count() == 1
+    tr_entity: CachedTranscriptionEntity = (
+        CachedTranscriptionEntity.select()
+        .where(
+            (
+                CachedTranscriptionEntity.transcription
+                == "144974c3-2477-4101-a486-6c276b0070aa"
+            )
+            & (CachedTranscriptionEntity.entity == entity)
+        )
+        .get()
+    )
+    assert tr_entity.offset == 0
+    assert tr_entity.length == 23
+    assert tr_entity.confidence == 1.0
+    assert tr_entity.worker_run_id is None
diff --git a/worker_generic_training_dataset/db.py b/worker_generic_training_dataset/db.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ae47d646fefec8bb002e2845a8f00750bc24b84
--- /dev/null
+++ b/worker_generic_training_dataset/db.py
@@ -0,0 +1,33 @@
+# -*- coding: utf-8 -*-
+
+from uuid import UUID
+
+from arkindex_export import Classification
+from arkindex_export.models import (
+    Element,
+    Entity,
+    EntityType,
+    Transcription,
+    TranscriptionEntity,
+)
+
+
+def retrieve_element(element_id: UUID) -> Element:
+    return Element.get_by_id(element_id)
+
+
+def list_classifications(element_id: UUID):
+    return Classification.select().where(Classification.element == element_id)
+
+
+def list_transcriptions(element_id: UUID):
+    return Transcription.select().where(Transcription.element == element_id)
+
+
+def list_transcription_entities(transcription_id: UUID):
+    return (
+        TranscriptionEntity.select()
+        .where(TranscriptionEntity.transcription == transcription_id)
+        .join(Entity, on=TranscriptionEntity.entity)
+        .join(EntityType, on=Entity.type)
+    )
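These helpers are thin peewee queries over the `arkindex_export` models, so they can be exercised directly against any export database — for instance the one bundled with the tests. A rough sketch (IDs taken from `tests/test_worker.py`, not part of the worker itself):

```python
from arkindex_export import open_database

from worker_generic_training_dataset.db import (
    list_classifications,
    list_transcriptions,
    retrieve_element,
)

open_database("tests/data/arkindex_export.sqlite")

# First page of the train folder
page = retrieve_element("e26e6803-18da-4768-be30-a0a68132107c")
for classification in list_classifications(page.id):
    print(classification.class_name, classification.confidence)

# One of its text lines, carrying the transcription checked in the tests
line = retrieve_element("6411b293-dee0-4002-ac82-ffc5cdcb499d")
for transcription in list_transcriptions(line.id):
    print(transcription.text)
```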
+ """ + try: + exports = list( + self.api_client.paginate( + "ListExports", + id=self.corpus_id, + ) + ) + except ErrorResponse as e: + logger.error( + f"Could not list exports of corpus ({self.corpus_id}): {str(e.content)}" + ) + raise e + + # Find the latest that is in "done" state + exports: List[dict] = sorted( + list(filter(lambda exp: exp["state"] == "done", exports)), + key=operator.itemgetter("updated"), + reverse=True, + ) + assert ( + len(exports) > 0 + ), f"No available exports found for the corpus {self.corpus_id}." + + # Download latest export + try: + export_id: str = exports[0]["id"] + logger.info(f"Downloading export ({export_id})...") + self.export: _TemporaryFileWrapper = self.api_client.request( + "DownloadExport", + id=export_id, + ) + logger.info(f"Downloaded export ({export_id}) @ `{self.export.name}`") + open_database(self.export.name) + except ErrorResponse as e: + logger.error( + f"Could not download export ({export_id}) of corpus ({self.corpus_id}): {str(e.content)}" + ) + raise e + + def insert_classifications(self, element: CachedElement) -> None: + logger.info("Listing classifications") + classifications: list[CachedClassification] = [ + CachedClassification( + id=classification.id, + element=element, + class_name=classification.class_name, + confidence=classification.confidence, + state=classification.state, + ) + for classification in list_classifications(element.id) + ] + if classifications: + logger.info(f"Inserting {len(classifications)} classification(s)") + with cache_database.atomic(): + CachedClassification.bulk_create( + model_list=classifications, + batch_size=BULK_BATCH_SIZE, + ) + + def insert_transcriptions( + self, element: CachedElement + ) -> List[CachedTranscription]: + logger.info("Listing transcriptions") + transcriptions: list[CachedTranscription] = [ + CachedTranscription( + id=transcription.id, + element=element, + text=transcription.text, + # Dodge not-null constraint for now + confidence=transcription.confidence or 1.0, + orientation=DEFAULT_TRANSCRIPTION_ORIENTATION, + worker_version_id=transcription.worker_version.id + if transcription.worker_version + else None, + ) + for transcription in list_transcriptions(element.id) + ] + if transcriptions: + logger.info(f"Inserting {len(transcriptions)} transcription(s)") + with cache_database.atomic(): + CachedTranscription.bulk_create( + model_list=transcriptions, + batch_size=BULK_BATCH_SIZE, + ) + return transcriptions + + def insert_entities(self, transcriptions: List[CachedTranscription]) -> None: + logger.info("Listing entities") + entities: List[CachedEntity] = [] + transcription_entities: List[CachedTranscriptionEntity] = [] + for transcription in transcriptions: + for transcription_entity in list_transcription_entities(transcription.id): + entity = CachedEntity( + id=transcription_entity.entity.id, + type=transcription_entity.entity.type.name, + name=transcription_entity.entity.name, + validated=transcription_entity.entity.validated, + metas=transcription_entity.entity.metas, + ) + entities.append(entity) + transcription_entities.append( + CachedTranscriptionEntity( + id=transcription_entity.id, + transcription=transcription, + entity=entity, + offset=transcription_entity.offset, + length=transcription_entity.length, + confidence=transcription_entity.confidence, + ) + ) + if entities: + # First insert entities since they are foreign keys on transcription entities + logger.info(f"Inserting {len(entities)} entities") + with cache_database.atomic(): + CachedEntity.bulk_create( + 
diff --git a/worker_generic_training_dataset/worker.py b/worker_generic_training_dataset/worker.py
index 8489d561aeba8354d60155803cdcb9c405d68112..f547e42b86574943bd3c422951b2ab3f88f91374 100644
--- a/worker_generic_training_dataset/worker.py
+++ b/worker_generic_training_dataset/worker.py
@@ -1,15 +1,360 @@
 # -*- coding: utf-8 -*-
-from arkindex_worker.worker import ElementsWorker
+import logging
+import operator
+import tempfile
+from argparse import Namespace
+from pathlib import Path
+from tempfile import _TemporaryFileWrapper
+from typing import List, Optional
+from uuid import UUID
 
+from apistar.exceptions import ErrorResponse
+from arkindex_export import Element, Image, open_database
+from arkindex_export.queries import list_children
+from arkindex_worker.cache import (
+    CachedClassification,
+    CachedElement,
+    CachedEntity,
+    CachedImage,
+    CachedTranscription,
+    CachedTranscriptionEntity,
+    create_tables,
+    create_version_table,
+)
+from arkindex_worker.cache import db as cache_database
+from arkindex_worker.cache import init_cache_db
+from arkindex_worker.image import download_image
+from arkindex_worker.utils import create_tar_zst_archive
+from arkindex_worker.worker.base import BaseWorker
+from worker_generic_training_dataset.db import (
+    list_classifications,
+    list_transcription_entities,
+    list_transcriptions,
+    retrieve_element,
+)
+from worker_generic_training_dataset.utils import build_image_url
 
-class Demo(ElementsWorker):
-    def process_element(self, element):
-        print("Demo processing element", element)
+logger: logging.Logger = logging.getLogger(__name__)
+
+BULK_BATCH_SIZE = 50
+DEFAULT_TRANSCRIPTION_ORIENTATION = "horizontal-lr"
+
+
+class DatasetExtractor(BaseWorker):
+    def configure(self) -> None:
+        self.args: Namespace = self.parser.parse_args()
+        if self.is_read_only:
+            super().configure_for_developers()
+        else:
+            super().configure()
+
+        if self.user_configuration:
+            logger.info("Overriding with user_configuration")
+            self.config.update(self.user_configuration)
+
+        # Read process information
+        self.read_training_related_information()
+
+        # Download corpus
+        self.download_latest_export()
+
+        # Initialize db that will be written
+        self.configure_cache()
+
+        # CachedImage downloaded and created in DB
+        self.cached_images = dict()
+
+        # Where to save the downloaded images
+        self.image_folder = Path(tempfile.mkdtemp(suffix="-arkindex-data"))
+        logger.info(f"Images will be saved at `{self.image_folder}`.")
+
+    def read_training_related_information(self) -> None:
+        """
+        Read from process information
+        - train_folder_id
+        - validation_folder_id
+        - test_folder_id (optional)
+        """
+        logger.info("Retrieving information from process_information")
+
+        train_folder_id = self.process_information.get("train_folder_id")
+        assert train_folder_id, "A training folder id is necessary to use this worker"
+        self.training_folder_id = UUID(train_folder_id)
+
+        val_folder_id = self.process_information.get("validation_folder_id")
+        assert val_folder_id, "A validation folder id is necessary to use this worker"
+        self.validation_folder_id = UUID(val_folder_id)
+
+        test_folder_id = self.process_information.get("test_folder_id")
+        self.testing_folder_id: UUID | None = (
+            UUID(test_folder_id) if test_folder_id else None
+        )
+
+    def configure_cache(self) -> None:
+        """
+        Create an SQLite database compatible with base-worker cache and initialize it.
+        """
+        self.use_cache = True
+        self.cache_path: Path = self.args.database or self.work_dir / "db.sqlite"
+        # Remove previous execution result if present
+        self.cache_path.unlink(missing_ok=True)
+
+        init_cache_db(self.cache_path)
+
+        create_version_table()
+
+        create_tables()
+
+    def download_latest_export(self) -> None:
+        """
+        Download the latest export of the current corpus.
+        Export must be in `"done"` state.
+        """
+        try:
+            exports = list(
+                self.api_client.paginate(
+                    "ListExports",
+                    id=self.corpus_id,
+                )
+            )
+        except ErrorResponse as e:
+            logger.error(
+                f"Could not list exports of corpus ({self.corpus_id}): {str(e.content)}"
+            )
+            raise e
+
+        # Find the latest that is in "done" state
+        exports: List[dict] = sorted(
+            list(filter(lambda exp: exp["state"] == "done", exports)),
+            key=operator.itemgetter("updated"),
+            reverse=True,
+        )
+        assert (
+            len(exports) > 0
+        ), f"No available exports found for the corpus {self.corpus_id}."
+
+        # Download latest export
+        try:
+            export_id: str = exports[0]["id"]
+            logger.info(f"Downloading export ({export_id})...")
+            self.export: _TemporaryFileWrapper = self.api_client.request(
+                "DownloadExport",
+                id=export_id,
+            )
+            logger.info(f"Downloaded export ({export_id}) @ `{self.export.name}`")
+            open_database(self.export.name)
+        except ErrorResponse as e:
+            logger.error(
+                f"Could not download export ({export_id}) of corpus ({self.corpus_id}): {str(e.content)}"
+            )
+            raise e
+
+    def insert_classifications(self, element: CachedElement) -> None:
+        logger.info("Listing classifications")
+        classifications: list[CachedClassification] = [
+            CachedClassification(
+                id=classification.id,
+                element=element,
+                class_name=classification.class_name,
+                confidence=classification.confidence,
+                state=classification.state,
+            )
+            for classification in list_classifications(element.id)
+        ]
+        if classifications:
+            logger.info(f"Inserting {len(classifications)} classification(s)")
+            with cache_database.atomic():
+                CachedClassification.bulk_create(
+                    model_list=classifications,
+                    batch_size=BULK_BATCH_SIZE,
+                )
+
+    def insert_transcriptions(
+        self, element: CachedElement
+    ) -> List[CachedTranscription]:
+        logger.info("Listing transcriptions")
+        transcriptions: list[CachedTranscription] = [
+            CachedTranscription(
+                id=transcription.id,
+                element=element,
+                text=transcription.text,
+                # Dodge not-null constraint for now
+                confidence=transcription.confidence or 1.0,
+                orientation=DEFAULT_TRANSCRIPTION_ORIENTATION,
+                worker_version_id=transcription.worker_version.id
+                if transcription.worker_version
+                else None,
+            )
+            for transcription in list_transcriptions(element.id)
+        ]
+        if transcriptions:
+            logger.info(f"Inserting {len(transcriptions)} transcription(s)")
+            with cache_database.atomic():
+                CachedTranscription.bulk_create(
+                    model_list=transcriptions,
+                    batch_size=BULK_BATCH_SIZE,
+                )
+        return transcriptions
+
+    def insert_entities(self, transcriptions: List[CachedTranscription]) -> None:
+        logger.info("Listing entities")
+        entities: List[CachedEntity] = []
+        transcription_entities: List[CachedTranscriptionEntity] = []
+        for transcription in transcriptions:
+            for transcription_entity in list_transcription_entities(transcription.id):
+                entity = CachedEntity(
+                    id=transcription_entity.entity.id,
+                    type=transcription_entity.entity.type.name,
+                    name=transcription_entity.entity.name,
+                    validated=transcription_entity.entity.validated,
+                    metas=transcription_entity.entity.metas,
+                )
+                entities.append(entity)
+                transcription_entities.append(
+                    CachedTranscriptionEntity(
+                        id=transcription_entity.id,
+                        transcription=transcription,
+                        entity=entity,
+                        offset=transcription_entity.offset,
+                        length=transcription_entity.length,
+                        confidence=transcription_entity.confidence,
+                    )
+                )
+        if entities:
+            # First insert entities since they are foreign keys on transcription entities
+            logger.info(f"Inserting {len(entities)} entities")
+            with cache_database.atomic():
+                CachedEntity.bulk_create(
+                    model_list=entities,
+                    batch_size=BULK_BATCH_SIZE,
+                )
+
+        if transcription_entities:
+            # Insert transcription entities
+            logger.info(
+                f"Inserting {len(transcription_entities)} transcription entities"
+            )
+            with cache_database.atomic():
+                CachedTranscriptionEntity.bulk_create(
+                    model_list=transcription_entities,
+                    batch_size=BULK_BATCH_SIZE,
+                )
+
+    def insert_element(
+        self, element: Element, parent_id: Optional[UUID] = None
+    ) -> None:
+        """
+        Insert the given element in the cache database.
+        Its image will also be saved to disk, if it wasn't already.
+
+        The insertion of an element includes:
+        - its classifications
+        - its transcriptions
+        - its transcriptions' entities (both Entity and TranscriptionEntity)
+
+        :param element: Element to insert.
+        :param parent_id: ID of the parent to use when creating the CachedElement. Do not specify for top-level elements.
+        """
+        logger.info(f"Processing element ({element.id})")
+        if element.image and element.image.id not in self.cached_images:
+            # Download image
+            logger.info("Downloading image")
+            download_image(url=build_image_url(element)).save(
+                self.image_folder / f"{element.image.id}.jpg"
+            )
+            # Insert image
+            logger.info("Inserting image")
+            # Store images in case some other elements use it as well
+            with cache_database.atomic():
+                self.cached_images[element.image.id] = CachedImage.create(
+                    id=element.image.id,
+                    width=element.image.width,
+                    height=element.image.height,
+                    url=element.image.url,
+                )
+
+        # Insert element
+        logger.info("Inserting element")
+        with cache_database.atomic():
+            cached_element: CachedElement = CachedElement.create(
+                id=element.id,
+                parent_id=parent_id,
+                type=element.type,
+                image=self.cached_images[element.image.id] if element.image else None,
+                polygon=element.polygon,
+                rotation_angle=element.rotation_angle,
+                mirrored=element.mirrored,
+                worker_version_id=element.worker_version.id
+                if element.worker_version
+                else None,
+                confidence=element.confidence,
+            )
+
+        # Insert classifications
+        self.insert_classifications(cached_element)
+
+        # Insert transcriptions
+        transcriptions: List[CachedTranscription] = self.insert_transcriptions(
+            cached_element
+        )
+
+        # Insert entities
+        self.insert_entities(transcriptions)
+
+    def process_split(self, split_name: str, split_id: UUID) -> None:
+        """
+        Insert all elements under the given parent folder (all queries are recursive).
+        - `page` elements are linked to this folder (via parent_id foreign key)
+        - `page` element children are linked to their `page` parent (via parent_id foreign key)
+        """
+        logger.info(
+            f"Filling the Base-Worker cache with information from children under element ({split_id})"
+        )
+        # Fill cache
+        # Retrieve parent and create parent
+        parent: Element = retrieve_element(split_id)
+        self.insert_element(parent)
+
+        # First list all pages
+        pages = list_children(split_id).join(Image).where(Element.type == "page")
+        nb_pages: int = pages.count()
+        for idx, page in enumerate(pages, start=1):
+            logger.info(f"Processing `{split_name}` page ({idx}/{nb_pages})")
+
+            # Insert page
+            self.insert_element(page, parent_id=split_id)
+
+            # List children
+            children = list_children(page.id)
+            nb_children: int = children.count()
+            for child_idx, child in enumerate(children, start=1):
+                logger.info(f"Processing child ({child_idx}/{nb_children})")
+                # Insert child
+                self.insert_element(child, parent_id=page.id)
+
+    def run(self):
+        self.configure()
+
+        # Iterate over the given splits
+        for split_name, split_id in [
+            ("Train", self.training_folder_id),
+            ("Validation", self.validation_folder_id),
+            ("Test", self.testing_folder_id),
+        ]:
+            if not split_id:
+                continue
+            self.process_split(split_name, split_id)
+
+        # TAR + ZSTD the image folder and store it as a task artifact
+        zstd_archive_path: Path = self.work_dir / "arkindex_data.zstd"
+        logger.info(f"Compressing the images to {zstd_archive_path}")
+        create_tar_zst_archive(source=self.image_folder, destination=zstd_archive_path)
 
 
 def main():
-    Demo(
-        description="Fill base-worker cache with information about dataset and extract images"
+    DatasetExtractor(
+        description="Fill base-worker cache with information about dataset and extract images",
+        support_cache=True,
     ).run()
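After a run, the `db.sqlite` written to the working directory is a regular base-worker cache database, so it can be inspected with the same models the tests use. A possible sanity check, assuming the default cache path from `configure_cache`:

```python
from arkindex_worker.cache import CachedElement, CachedTranscription, init_cache_db

# Bind the cache models to the database produced by the worker
init_cache_db("db.sqlite")

print("elements:", CachedElement.select().count())
print("transcriptions:", CachedTranscription.select().count())
```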