diff --git a/README.md b/README.md
index af907a18f989a703b0039b869f4abd64c2a7f243..e477148e65cbbfef2be64094cde549cc501bfd83 100644
--- a/README.md
+++ b/README.md
@@ -30,3 +30,5 @@ tox
 ```
 
 To recreate tox virtual environment (e.g. a dependencies update), you may run `tox -r`
+
+Tests use an export from [`IAM | GT | DLA`](https://preprod.arkindex.teklia.com/browse/53919f29-f79c-4c3e-b088-0e955950879f?top_level=true&folder=true).
diff --git a/ci/build.sh b/ci/build.sh
index f29f50f27b88056216e6f880bb713012fc4e9956..7c43de460fe5ad1e9b5eeb593feb9c1a9db145f3 100755
--- a/ci/build.sh
+++ b/ci/build.sh
@@ -5,6 +5,8 @@
 # Will automatically login to a registry if CI_REGISTRY, CI_REGISTRY_USER and CI_REGISTRY_PASSWORD are set.
 # Will only push an image if $CI_REGISTRY is set.
 
+# Hard-code the version for the proof of concept; this bypasses the CI_COMMIT_TAG fallback below.
+VERSION="POC"
 if [ -z "$VERSION" ]; then
     VERSION=${CI_COMMIT_TAG:-latest}
 fi
diff --git a/tests/conftest.py b/tests/conftest.py
index 14dfec177864ca1da9e6470d5bcb84f1a51e530d..0121c272d8c5ef29181e34e3b6cb88cc871a84f7 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,11 +1,15 @@
 # -*- coding: utf-8 -*-
 import os
+from pathlib import Path
 
 import pytest
 
 from arkindex.mock import MockApiClient
+from arkindex_export import open_database
 from arkindex_worker.worker.base import BaseWorker
 
+DATA_PATH = Path(__file__).parent / "data"
+
 
 @pytest.fixture(autouse=True)
 def setup_environment(responses, monkeypatch):
@@ -26,3 +30,9 @@ def setup_environment(responses, monkeypatch):
 
     # Setup a mock api client instead of using a real one
     monkeypatch.setattr(BaseWorker, "setup_api_client", lambda _: MockApiClient())
+
+
+@pytest.fixture(scope="session", autouse=True)
+def arkindex_db():
+    # Point the arkindex_export models at the SQLite test export for the whole session.
+    open_database(DATA_PATH / "arkindex_export.sqlite")
diff --git a/tests/test_worker.py b/tests/test_worker.py
index 220ba5127476f0e53acc00a3295f90dea2293ff3..3c3703f63277493da0e2b6ab75ec3757f3e70bd0 100644
--- a/tests/test_worker.py
+++ b/tests/test_worker.py
@@ -1,5 +1,73 @@
 # -*- coding: utf-8 -*-
+import tempfile
+from argparse import Namespace
+from pathlib import Path
 
-
-def test_dummy():
-    assert True
+from arkindex_worker.cache import (
+    CachedClassification,
+    CachedElement,
+    CachedEntity,
+    CachedImage,
+    CachedTranscription,
+    CachedTranscriptionEntity,
+)
+from worker_generic_training_dataset.worker import DatasetExtractor
+
+
+def test_process_split():
+    # Parent is the train folder
+    parent_id = "a0c4522d-2d80-4766-a01c-b9d686f41f6a"
+
+    worker = DatasetExtractor()
+    # Mock important configuration steps
+    worker.args = Namespace(dev=False)
+    worker.initialize_database()
+    worker.cached_images = dict()
+
+    # Where to save the downloaded images
+    image_folder = Path(tempfile.mkdtemp())
+
+    worker.process_split("train", parent_id, image_folder)
+
+    # Should have created two pages under the train folder
+    assert (
+        CachedElement.select().where(CachedElement.parent_id == parent_id).count() == 2
+    )
+
+    # Should have created 8 text_lines under the first page
+    assert (
+        CachedElement.select()
+        .where(CachedElement.parent_id == "e26e6803-18da-4768-be30-a0a68132107c")
+        .count()
+        == 8
+    )
+
+    # Should have created one classification linked to the first page
+    assert (
+        CachedClassification.select()
+        .where(CachedClassification.element == "e26e6803-18da-4768-be30-a0a68132107c")
+        .count()
+        == 1
+    )
+
+    # Should have created two images
+    assert CachedImage.select().count() == 2
+    assert sorted(image_folder.rglob("*")) == [
+        image_folder / "80a84b30-1ae1-4c13-95d6-7d0d8ee16c51.jpg",
+        image_folder / "e3c755f2-0e1c-468e-ae4c-9206f0fd267a.jpg",
+    ]
+
+    # Should have created a transcription linked to the first line of the first page
+    assert (
+        CachedTranscription.select()
+        .where(CachedTranscription.element == "6411b293-dee0-4002-ac82-ffc5cdcb499d")
+        .count()
+        == 1
+    )
+
+    # Should have created a transcription entity linked to the transcription of the first line of the first page
+    assert CachedEntity.select().count() == 1
+    assert CachedTranscriptionEntity.select().count() == 1
+
+    assert CachedEntity.get_by_id("e04b0323-0dda-4f76-b218-af40f7d40c84")
diff --git a/worker_generic_training_dataset/db.py b/worker_generic_training_dataset/db.py
index a608ab93e5cfe626b273d022966282067a9fab89..6d413747f83a374b571653e950e175be8bb317fd 100644
--- a/worker_generic_training_dataset/db.py
+++ b/worker_generic_training_dataset/db.py
@@ -25,8 +25,7 @@ def retrieve_element(element_id: str):
 
 
 def list_classifications(element_id: str):
-    query = Classification.select().where(Classification.element_id == element_id)
-    return query
+    return Classification.select().where(Classification.element_id == element_id)
 
 
 def parse_transcription(transcription: NamedTuple, element: CachedElement):
diff --git a/worker_generic_training_dataset/worker.py b/worker_generic_training_dataset/worker.py
index 30120a142d8c518169d5c14acecafb46f266446f..6dc07181a5c72f7b56816847496b58cad0f284e0 100644
--- a/worker_generic_training_dataset/worker.py
+++ b/worker_generic_training_dataset/worker.py
@@ -76,11 +76,11 @@ class DatasetExtractor(BaseWorker):
         """
         logger.info("Retrieving information from process_information")
 
-        train_folder_id = self.config.get("train_folder_id")
+        train_folder_id = self.process_information.get("train_folder_id")
         assert train_folder_id, "A training folder id is necessary to use this worker"
         self.training_folder_id = UUID(train_folder_id)
 
-        val_folder_id = self.config.get("validation_folder_id")
+        val_folder_id = self.process_information.get("validation_folder_id")
         assert val_folder_id, "A validation folder id is necessary to use this worker"
         self.validation_folder_id = UUID(val_folder_id)
 
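
Note on the test setup: `arkindex_db` in `tests/conftest.py` is an autouse, session-scoped fixture, so every test in the suite runs against the opened SQLite export with no extra setup. Below is a minimal sketch of a follow-up test reusing the `list_classifications` helper touched in `worker_generic_training_dataset/db.py`; the page UUID comes from `test_process_split`, and expecting exactly one classification mirrors that test's assertion, so treat both as assumptions about the export's contents.

```python
# Sketch only: relies on the autouse `arkindex_db` fixture from tests/conftest.py
# having already opened tests/data/arkindex_export.sqlite for the session.
from worker_generic_training_dataset.db import list_classifications

# First-page UUID reused from test_process_split (assumption: the export holds
# exactly one classification for this page, as that test implies).
FIRST_PAGE_ID = "e26e6803-18da-4768-be30-a0a68132107c"


def test_list_classifications():
    # list_classifications returns a lazy Peewee query; count() executes it
    # against the export database opened by the fixture.
    assert list_classifications(FIRST_PAGE_ID).count() == 1
```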