# Copyright Teklia (contact@teklia.com) & Denis Coquenet # This code is licensed under CeCILL-C # -*- coding: utf-8 -*- from operator import itemgetter import pytest from arkindex_export import Dataset, DatasetElement, Element from dan import TRAIN_NAME from dan.datasets.extract.db import ( get_dataset_elements, get_elements, get_transcription_entities, get_transcriptions, ) def test_get_dataset_elements(mock_database): """ Assert dataset elements retrieval output against verified results """ dataset_elements = get_dataset_elements( dataset=Dataset.select().get(), split=TRAIN_NAME, ) # ID verification assert all( isinstance(dataset_element, DatasetElement) for dataset_element in dataset_elements ) assert [dataset_element.element.id for dataset_element in dataset_elements] == [ "train-page_1", "train-page_2", ] def test_get_elements(mock_database): """ Assert elements retrieval output against verified results """ elements = get_elements( parent_id="train-page_1", element_type=["text_line"], ) # ID verification assert all(isinstance(element, Element) for element in elements) assert [element.id for element in elements] == [ "train-page_1-line_1", "train-page_1-line_2", "train-page_1-line_3", "train-page_1-line_4", ] @pytest.mark.parametrize( "sources", ([False], ["id"], [], [False, "id"]), ) def test_get_transcriptions(sources, mock_database): """ Assert transcriptions retrieval output against verified results """ worker_versions = [ f"worker_version_{source}" if isinstance(source, str) else source for source in sources ] worker_runs = [ f"worker_run_{source}" if isinstance(source, str) else source for source in sources ] element_id = "train-page_1-line_1" transcriptions = get_transcriptions( element_id=element_id, transcription_worker_versions=worker_versions, transcription_worker_runs=worker_runs, ) expected_transcriptions = [] if not sources or False in sources: expected_transcriptions.append( { "text": "Laulont Francois 8", "worker_version": None, "worker_run": None, } ) if not sources or "id" in sources: expected_transcriptions.append( { "text": "laulont francois 8", "worker_version": "worker_version_id", "worker_run": "worker_run_id", } ) # Do not compare IDs transcriptions = transcriptions.dicts() for transcription in transcriptions: del transcription["id"] assert sorted(transcriptions, key=itemgetter("text")) == expected_transcriptions @pytest.mark.parametrize("source", (False, "id", None)) @pytest.mark.parametrize( "supported_types", (["surname"], ["surname", "firstname", "age"]) ) def test_get_transcription_entities(source, mock_database, supported_types): worker_version = f"worker_version_{source}" if isinstance(source, str) else source worker_run = f"worker_run_{source}" if isinstance(source, str) else source transcription_id = "train-page_1-line_1" + ("source" if source else "") entities = get_transcription_entities( transcription_id=transcription_id, entity_worker_versions=[worker_version] if worker_version is not None else [], entity_worker_runs=[worker_run] if worker_run is not None else [], supported_types=supported_types, ) expected_entities = [ { "name": "Laulont", "type": "surname", "offset": 0, "length": 7, }, { "name": "Francois", "type": "firstname", "offset": 9, "length": 8, }, { "name": "8", "type": "age", "offset": 19, "length": 1, }, ] expected_entities = list( filter(lambda ent: ent["type"] in supported_types, expected_entities) ) for entity in expected_entities: if source: entity["name"] = entity["name"].lower() entity["worker_version"] = worker_version or None entity["worker_run"] = worker_run or None assert sorted(entities, key=itemgetter("offset")) == expected_entities