# -*- coding: utf-8 -*-
import json
import uuid
from operator import itemgetter
from typing import List

import pytest

from arkindex_export import (
    Element,
    ElementPath,
    Entity,
    EntityType,
    Image,
    ImageServer,
    Transcription,
    TranscriptionEntity,
    WorkerRun,
    WorkerVersion,
    database,
)
from tests import FIXTURES


@pytest.fixture(scope="session")
def mock_database(tmp_path_factory):
    """Build a temporary SQLite export database populated with test data, shared across the test session."""

    def create_transcription_entity(
        transcription: Transcription,
        worker_version: str | None,
        type: str,
        name: str,
        offset: int,
    ) -> None:
        entity_type, _ = EntityType.get_or_create(
            name=type, defaults={"id": f"{type}_id"}
        )
        entity = Entity.create(
            id=str(uuid.uuid4()),
            name=name,
            type=entity_type,
            worker_version=worker_version,
        )
        TranscriptionEntity.create(
            entity=entity,
            length=len(name),
            offset=offset,
            transcription=transcription,
            worker_version=worker_version,
        )

    def create_transcriptions(element: Element, entities: List[dict]) -> None:
        if not entities:
            return

        # Add a transcription with entities
        entities = sorted(entities, key=itemgetter("offset"))
        # Shift offsets to account for the extra spaces added below to test
        # the "keep_spaces" parameter of the "extract" command
        for offset, entity in enumerate(entities[1:], start=1):
            entity["offset"] += offset

        for worker_version in [None, "worker_version_id"]:
            # Use different transcriptions to filter by worker version
            if worker_version == "worker_version_id":
                for entity in entities:
                    entity["name"] = entity["name"].lower()
            transcription = Transcription.create(
                id=element.id + (worker_version or ""),
                # Join with double spaces to test the "keep_spaces" parameter
                # of the "extract" command
                text="  ".join(map(itemgetter("name"), entities)),
                element=element,
                worker_version=worker_version,
            )
            for entity in entities:
                create_transcription_entity(
                    transcription=transcription,
                    worker_version=worker_version,
                    **entity,
                )

    def create_element(id: str, parent: Element | None = None) -> None:
        element_path = (FIXTURES / "extraction" / "elements" / id).with_suffix(".json")
        element_json = json.loads(element_path.read_text())

        element_type = element_json["type"]

        image_path = (
            FIXTURES / "extraction" / "images" / element_type / id
        ).with_suffix(".jpg")

        polygon = element_json.get("polygon")
        # Always use page images because polygons are based on the full image
        image, _ = (
            Image.get_or_create(
                id=id + "-image",
                defaults={
                    "server": image_server,
                    # Use the path to the image instead of an actual URL
                    # since we won't be doing any download
                    "url": image_path,
                    "width": 0,
                    "height": 0,
                },
            )
            if polygon
            else (None, False)
        )

        element = Element.create(
            id=id,
            name=id,
            type=element_type,
            image=image,
            polygon=json.dumps(polygon) if polygon else None,
            created=0.0,
            updated=0.0,
        )
        if parent:
            ElementPath.create(id=str(uuid.uuid4()), parent=parent, child=element)

        create_transcriptions(
            element=element,
            entities=element_json.get("transcription_entities", []),
        )

        # Recursively create children
        for child in element_json.get("children", []):
            create_element(id=child, parent=element)

    MODELS = [
        WorkerVersion,
        WorkerRun,
        ImageServer,
        Image,
        Element,
        ElementPath,
        EntityType,
        Entity,
        Transcription,
        TranscriptionEntity,
    ]

    # Initialisation
    tmp_path = tmp_path_factory.mktemp("data")
    database_path = tmp_path / "db.sqlite"
    database.init(
        database_path,
        pragmas={
            # Recommended settings from peewee
            # http://docs.peewee-orm.com/en/latest/peewee/database.html#recommended-settings
            # Do not set journal mode to WAL as it writes in the database
            "cache_size": -1 * 64000,  # 64MB
            "foreign_keys": 1,
            "ignore_check_constraints": 0,
            "synchronous": 0,
        },
    )
    database.connect()

    # Create tables
    database.create_tables(MODELS)

    image_server = ImageServer.create(
        url="http://image/server/url",
        display_name="Image server",
    )

    WorkerVersion.create(
        id="worker_version_id",
        slug="worker_version",
        name="Worker version",
        repository_url="http://repository/url",
        revision="main",
        type="worker",
    )

    # Create folders
    create_element(id="root")

    # Create data for the entities extraction tests
    # Create a transcription
    transcription = Transcription.create(
        id="tr-with-entities",
        text="The great king Charles III has eaten \nwith us.",
        element=Element.select().first(),
    )
    WorkerVersion.bulk_create(
        [
            WorkerVersion(
                id=f"{nestation}-id",
                slug=nestation,
                name=nestation,
                repository_url="http://repository/url",
                revision="main",
                type="worker",
            )
            for nestation in ("nested", "non-nested")
        ]
    )

    entities = [
        # Non-nested entities
        {
            "worker_version": "non-nested-id",
            "type": "adj",
            "name": "great",
            "offset": 4,
        },
        {
            "worker_version": "non-nested-id",
            "type": "name",
            "name": "Charles",
            "offset": 15,
        },
        {
            "worker_version": "non-nested-id",
            "type": "person",
            "name": "us",
            "offset": 43,
        },
        # Nested entities
        {
            "worker_version": "nested-id",
            "type": "fullname",
            "name": "Charles III",
            "offset": 15,
        },
        {
            "worker_version": "nested-id",
            "type": "name",
            "name": "Charles",
            "offset": 15,
        },
        {
            "worker_version": "nested-id",
            "type": "person",
            "name": "us",
            "offset": 43,
        },
    ]
    for entity in entities:
        create_transcription_entity(transcription=transcription, **entity)

    return database_path


@pytest.fixture
def training_config():
    """Load the training configuration used by the tests."""
    return json.loads((FIXTURES.parent.parent / "configs" / "tests.json").read_text())


@pytest.fixture
def evaluate_config():
    """Load the evaluation configuration used by the tests."""
    return json.loads((FIXTURES.parent.parent / "configs" / "eval.json").read_text())


@pytest.fixture
def split_content():
    """Load the expected split content, resolving the {FIXTURES} placeholder in IIIF URLs."""
    splits = json.loads((FIXTURES / "extraction" / "split.json").read_text())
    for split in splits:
        for element_id in splits[split]:
            splits[split][element_id]["image"]["iiif_url"] = splits[split][element_id][
                "image"
            ]["iiif_url"].replace("{FIXTURES}", str(FIXTURES))
    return splits
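

# Minimal usage sketch, assuming a test runs in the same session: the fixture
# above leaves the database connected, so peewee models can be queried
# directly. The test name and assertions below are illustrative, not an
# existing test of this suite.
#
#     def test_mock_database(mock_database):
#         # "root" and "tr-with-entities" are created by the fixture above
#         assert Element.get_by_id("root").name == "root"
#         assert Transcription.get_by_id("tr-with-entities")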