Something went wrong on our end
-
Manon Blanco authored
test_db.py 4.05 KiB
# -*- coding: utf-8 -*-
from operator import itemgetter
import pytest
from arkindex_export import Dataset, DatasetElement, Element
from dan.datasets.extract.arkindex import TRAIN_NAME
from dan.datasets.extract.db import (
get_dataset_elements,
get_elements,
get_transcription_entities,
get_transcriptions,
)
def test_get_dataset_elements(mock_database):
    """
    Check that querying the train split of the mocked dataset returns
    DatasetElement rows whose elements are, in order, the two train pages.
    """
    dataset = Dataset.select().get()
    train_rows = get_dataset_elements(dataset=dataset, split=TRAIN_NAME)

    # Type check: every returned row must be a DatasetElement
    for row in train_rows:
        assert isinstance(row, DatasetElement)

    # ID check: the linked element IDs, compared as an ordered list
    found_ids = [row.element.id for row in train_rows]
    assert found_ids == [
        "train-page_1",
        "train-page_2",
    ]
def test_get_elements(mock_database):
    """
    Check that requesting the `text_line` elements under parent
    `train-page_1` returns Element rows with the four expected IDs, in order.
    """
    expected_ids = [
        "train-page_1-line_1",
        "train-page_1-line_2",
        "train-page_1-line_3",
        "train-page_1-line_4",
    ]
    lines = get_elements(parent_id="train-page_1", element_type=["text_line"])

    # Type check, then ordered ID comparison
    assert all(isinstance(line, Element) for line in lines)
    assert [line.id for line in lines] == expected_ids
@pytest.mark.parametrize(
    "worker_versions",
    ([False], ["worker_version_id"], [], [False, "worker_version_id"]),
)
def test_get_transcriptions(worker_versions, mock_database):
    """
    Check the transcriptions of line `train-page_1-line_1` under each
    worker-version filter.

    In ``worker_versions``, ``False`` selects the transcription that has no
    worker version and ``"worker_version_id"`` selects the one produced by
    that worker version. An empty filter is expected to return both.
    """
    element_id = "train-page_1-line_1"
    transcriptions = get_transcriptions(
        element_id=element_id,
        transcription_worker_versions=worker_versions,
    )

    # An empty filter means no filtering at all
    keep_all = not worker_versions
    expected_transcriptions = []
    if keep_all or False in worker_versions:
        expected_transcriptions.append(
            {"text": "Caillet Maurice 28.9.06", "worker_version_id": None}
        )
    if keep_all or "worker_version_id" in worker_versions:
        expected_transcriptions.append(
            {"text": "caillet maurice 28.9.06", "worker_version_id": "worker_version_id"}
        )

    def summarize(transcription):
        # Reduce a transcription row to its text and worker version ID (or None)
        worker_version = transcription.worker_version
        return {
            "text": transcription.text,
            "worker_version_id": worker_version.id if worker_version else None,
        }

    found = sorted(map(summarize, transcriptions), key=itemgetter("text"))
    assert found == expected_transcriptions
@pytest.mark.parametrize("worker_version", (False, "worker_version_id", None))
@pytest.mark.parametrize(
    "supported_types", (["surname"], ["surname", "firstname", "birthdate"])
)
def test_get_transcription_entities(worker_version, mock_database, supported_types):
    """
    Assert transcription entities retrieval output against verified results.

    The queried transcription ID is ``"train-page_1-line_1"`` suffixed with
    ``worker_version`` when it is truthy. ``False`` and ``None`` both select
    the unsuffixed transcription and expect a ``None`` worker version.

    Only entities whose type is listed in ``supported_types`` are expected,
    compared after sorting the results by offset. When a worker version is
    given, the expected entity names are lower-cased.
    """
    transcription_id = "train-page_1-line_1" + (worker_version or "")
    entities = get_transcription_entities(
        transcription_id=transcription_id,
        entity_worker_versions=[worker_version],
        supported_types=supported_types,
    )

    # Reference entities, in offset order; kept immutable so each run derives
    # its own expected dicts instead of mutating shared ones in place.
    reference_entities = (
        {"name": "Caillet", "type": "surname", "offset": 0, "length": 7},
        {"name": "Maurice", "type": "firstname", "offset": 9, "length": 7},
        {"name": "28.9.06", "type": "birthdate", "offset": 18, "length": 7},
    )
    expected_entities = [
        {
            **entity,
            # Entities from the worker-version transcription are expected lower-cased
            "name": entity["name"].lower() if worker_version else entity["name"],
            "worker_version": worker_version or None,
        }
        for entity in reference_entities
        if entity["type"] in supported_types
    ]

    assert sorted(entities, key=itemgetter("offset")) == expected_entities