# -*- coding: utf-8 -*-

from operator import itemgetter

import pytest

from arkindex_export import Dataset, DatasetElement, Element
from dan.datasets.extract.arkindex import TRAIN_NAME
from dan.datasets.extract.db import (
    get_dataset_elements,
    get_elements,
    get_transcription_entities,
    get_transcriptions,
)


def test_get_dataset_elements(mock_database):
    """
    Assert dataset elements retrieval output against verified results
    """
    dataset_elements = get_dataset_elements(
        dataset=Dataset.select().get(),
        split=TRAIN_NAME,
    )

    # ID verification
    assert all(
        isinstance(dataset_element, DatasetElement)
        for dataset_element in dataset_elements
    )
    assert [dataset_element.element.id for dataset_element in dataset_elements] == [
        "train-page_1",
        "train-page_2",
    ]


def test_get_elements(mock_database):
    """
    Assert elements retrieval output against verified results
    """
    elements = get_elements(
        parent_id="train-page_1",
        element_type=["text_line"],
    )

    # ID verification
    assert all(isinstance(element, Element) for element in elements)
    assert [element.id for element in elements] == [
        "train-page_1-line_1",
        "train-page_1-line_2",
        "train-page_1-line_3",
        "train-page_1-line_4",
    ]


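# The worker_versions filter mixes False (manual transcriptions, no worker version)
# and explicit worker version IDs; an empty list applies no filtering.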
@pytest.mark.parametrize(
    "worker_versions",
    ([False], ["worker_version_id"], [], [False, "worker_version_id"]),
)
def test_get_transcriptions(worker_versions, mock_database):
    """
    Assert transcriptions retrieval output against verified results
    """
    element_id = "train-page_1-line_1"
    transcriptions = get_transcriptions(
        element_id=element_id,
        transcription_worker_versions=worker_versions,
    )

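    # Build the expected transcriptions depending on the worker version filter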
    expected_transcriptions = []
    if not worker_versions or False in worker_versions:
        expected_transcriptions.append(
            {
                "text": "Caillet  Maurice  28.9.06",
                "worker_version_id": None,
            }
        )

    if not worker_versions or "worker_version_id" in worker_versions:
        expected_transcriptions.append(
            {
                "text": "caillet  maurice  28.9.06",
                "worker_version_id": "worker_version_id",
            }
        )

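    # Sort retrieved transcriptions by text so the comparison is order-independent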
    assert (
        sorted(
            [
                {
                    "text": transcription.text,
                    "worker_version_id": transcription.worker_version.id
                    if transcription.worker_version
                    else None,
                }
                for transcription in transcriptions
            ],
            key=itemgetter("text"),
        )
        == expected_transcriptions
    )


@pytest.mark.parametrize("worker_version", (False, "worker_version_id", None))
@pytest.mark.parametrize(
    "supported_types", (["surname"], ["surname", "firstname", "birthdate"])
)
def test_get_transcription_entities(worker_version, mock_database, supported_types):
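    """
    Assert transcription entities retrieval output against verified results
    """
    # Transcription IDs in the mock database are built from the element ID,
    # suffixed with the worker version ID when the transcription was produced by a worker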
    transcription_id = "train-page_1-line_1" + (worker_version or "")
    entities = get_transcription_entities(
        transcription_id=transcription_id,
        entity_worker_versions=[worker_version],
        supported_types=supported_types,
    )

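    # All entities present on the mock line; filtered below to the supported types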
    expected_entities = [
        {
            "name": "Caillet",
            "type": "surname",
            "offset": 0,
            "length": 7,
        },
        {
            "name": "Maurice",
            "type": "firstname",
            "offset": 9,
            "length": 7,
        },
        {
            "name": "28.9.06",
            "type": "birthdate",
            "offset": 18,
            "length": 7,
        },
    ]

    expected_entities = list(
        filter(lambda ent: ent["type"] in supported_types, expected_entities)
    )
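    # Worker-produced entities come from the lowercased transcription and carry
    # the worker version; manual ones keep the original casing and have no version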
    for entity in expected_entities:
        if worker_version:
            entity["name"] = entity["name"].lower()
        entity["worker_version"] = worker_version or None

    assert (
        sorted(
            entities,
            key=itemgetter("offset"),
        )
        == expected_entities
    )