# test_extract.py
# -*- coding: utf-8 -*-

import json
import pickle
import re
from operator import methodcaller
from typing import NamedTuple

import pytest

from arkindex_export import (
    DatasetElement,
    Element,
    Transcription,
    TranscriptionEntity,
)
from dan.datasets.extract.arkindex import ArkindexExtractor
from dan.datasets.extract.db import get_transcription_entities
from dan.datasets.extract.exceptions import (
    NoTranscriptionError,
    UnknownTokenInText,
)
from dan.datasets.extract.utils import (
    EntityType,
    entities_to_xml,
    normalize_linebreaks,
    normalize_spaces,
)
from dan.utils import parse_tokens
from tests import FIXTURES

# Directory of fixture files for the extraction tests (e.g. "tokens.yml")
EXTRACTION_DATA_PATH = FIXTURES / "extraction"

# Two consecutive regular spaces in a raw transcription
TWO_SPACES_REGEX = re.compile(r" {2}")
# An entity start token followed by its separating space.
# Inside a character class "|" is a literal, not an alternation, so it must
# not be listed: otherwise "| " would be matched as well.
ENTITY_TOKEN_SPACE = re.compile(r"[ⓢⓕⓑ] ")
# Two consecutive space tokens in a language model corpus
TWO_SPACES_LM_REGEX = re.compile(r"▁ ▁")

class Entity(NamedTuple):
    """NamedTuple to mock actual database result.

    Declared with the class syntax: the keyword-argument functional form
    (``NamedTuple("Entity", offset=int, ...)``) is deprecated since
    Python 3.13. Field names and order are unchanged.
    """

    offset: int
    length: int
    type: str
    value: str

# Entity types keyed by a one-letter code, consumed by `filter_tokens`.
# NOTE(review): every start/end token is an empty string here; presumably
# distinct token characters were intended — confirm against the project's
# token definitions before relying on these values.
TOKENS = {
    "P": EntityType(start="", end=""),
    "D": EntityType(start="", end=""),
    "N": EntityType(start="", end=""),
    "I": EntityType(start="", end=""),
}


def filter_tokens(keys):
    """Return the entries of ``TOKENS`` whose key is contained in ``keys``."""
    selected = {}
    for name, entity_type in TOKENS.items():
        if name in keys:
            selected[name] = entity_type
    return selected


@pytest.mark.parametrize(
    ("text", "trimmed"),
    [
        ("no_spaces", "no_spaces"),
        (" beginning", "beginning"),
        ("ending ", "ending"),
        (" both ", "both"),
        ("    consecutive", "consecutive"),
        ("\ttab", "tab"),
        ("\t tab", "tab"),
        (" \ttab", "tab"),
        ("no|space", "no|space"),
    ],
)
def test_normalize_spaces(text, trimmed):
    """Check `normalize_spaces` against each raw text / expected text pair."""
    normalized = normalize_spaces(text)
    assert normalized == trimmed


@pytest.mark.parametrize(
    ("text", "trimmed"),
    [
        ("no_linebreaks", "no_linebreaks"),
        ("\nbeginning", "beginning"),
        ("ending\n", "ending"),
        ("\nboth\n", "both"),
        ("\n\n\nconsecutive", "consecutive"),
        ("\rcarriage_return", "carriage_return"),
        ("\r\ncarriage_return+linebreak", "carriage_return+linebreak"),
        ("\n\r\r\n\ncarriage_return+linebreak", "carriage_return+linebreak"),
        ("no|linebreaks", "no|linebreaks"),
    ],
)
def test_normalize_linebreaks(text, trimmed):
    """Check `normalize_linebreaks` against each raw text / expected text pair."""
    normalized = normalize_linebreaks(text)
    assert normalized == trimmed


def test_process_element_unknown_token_in_text_error(mock_database, tmp_path):
    """Processing an element whose transcription holds "⁇" must raise `UnknownTokenInText`."""
    extractor = ArkindexExtractor(output=tmp_path / "extraction")

    # Pick a dataset element, then overwrite the stored transcription text
    # with an invalid one
    first_dataset_element = DatasetElement.select().first()
    target_element = first_dataset_element.element
    Transcription.update({Transcription.text: "Is this text valid⁇"}).execute()

    expected_message = re.escape(
        f"Unknown token found in the transcription text of element ({target_element.id})"
    )
    with pytest.raises(UnknownTokenInText, match=expected_message):
        extractor.process_element(first_dataset_element, target_element)


@pytest.mark.parametrize(
    "load_entities,keep_spaces,transcription_entities_worker_version,expected_subword_language_corpus,subword_vocab_size",
    (
        (
            True,
            True,
            "worker_version_id",
            """▁ ⓢ c a i l l e t ▁ ⓕ m a u r i c e ▁ ⓑ 28. 9.0 6
▁ ⓢ re b ou l ▁ ⓕ j e a n ▁ ⓑ 30. 9.0 2
▁ ⓢ b a re y re ▁ ⓕ j e a n ▁ ⓑ 28. 3 . 1 1
▁ ⓢ r ou s s y ▁ ⓕ j e a n ▁ ⓑ 4 . 1 1 . 1 4
▁ ⓢ m a r i n ▁ ⓕ m a r c e l ▁ ⓑ 1 0 . 8 . 0 6
▁ ⓢ a m i c a l ▁ ⓕ e l o i ▁ ⓑ 1 1 . 1 0 . 0 4
▁ ⓢ b i r o s ▁ ⓕ m a e l ▁ ⓑ 30. 1 0 . 1 0""",
            40,
        ),
        (
            True,
            False,
            "worker_version_id",
            """▁ ⓢ c a i l l e t ▁ ⓕ m a u r i c e ▁ ⓑ 28. 9.0 6
▁ ⓢ re b ou l ▁ ⓕ j e a n ▁ ⓑ 30. 9.0 2
▁ ⓢ b a re y re ▁ ⓕ j e a n ▁ ⓑ 28. 3 . 1 1
▁ ⓢ r ou s s y ▁ ⓕ j e a n ▁ ⓑ 4 . 1 1 . 1 4
▁ ⓢ m a r i n ▁ ⓕ m a r c e l ▁ ⓑ 1 0 . 8 . 0 6
▁ ⓢ a m i c a l ▁ ⓕ e l o i ▁ ⓑ 1 1 . 1 0 . 0 4
▁ ⓢ b i r o s ▁ ⓕ m a e l ▁ ⓑ 30. 1 0 . 1 0""",
            40,
        ),
        (
            False,
            True,
            "worker_version_id",
            """▁ ca i l l e t ▁ ma u r i ce ▁ 28. 9.0 6
▁ re b o u l ▁ j e a n ▁ 30. 9.0 2
▁ b a re y re ▁ j e a n ▁ 28. 3 . 1 1
▁ r o u s s y ▁ j e a n ▁ 4 . 11.1 4
▁ ma r i n ▁ ma r ce l ▁ 10. 8 . 0 6
▁ a m i ca l ▁ el o i ▁ 11.1 0 . 0 4
▁ b i r o s ▁ ma el ▁ 30. 10. 1 0""",
            40,
        ),
        (
            False,
            False,
            "worker_version_id",
            """▁ ca i l l e t ▁ ma u r i ce ▁ 28. 9.0 6
▁ re b o u l ▁ j e a n ▁ 30. 9.0 2
▁ b a re y re ▁ j e a n ▁ 28. 3 . 1 1
▁ r o u s s y ▁ j e a n ▁ 4 . 11.1 4
▁ ma r i n ▁ ma r ce l ▁ 10. 8 . 0 6
▁ a m i ca l ▁ el o i ▁ 11.1 0 . 0 4
▁ b i r o s ▁ ma el ▁ 30. 10. 1 0""",
            40,
        ),
        (
            True,
            True,
            False,
            """▁ ⓢ C a i l l e t ▁ ⓕ M a u r i c e ▁ ⓑ 2 8 . 9 . 0 6
▁ ⓢ R e b o u l ▁ ⓕ J e a n ▁ ⓑ 3 0 . 9 . 0 2
▁ ⓢ B a r e y r e ▁ ⓕ J e a n ▁ ⓑ 2 8 . 3 . 1 1
▁ ⓢ R o u s s y ▁ ⓕ J e a n ▁ ⓑ 4 . 1 1 . 1 4
▁ ⓢ M a r i n ▁ ⓕ M a r c e l ▁ ⓑ 1 0 . 8 . 0 6
▁ ⓢ A m i c a l ▁ ⓕ E l o i ▁ ⓑ 1 1 . 1 0 . 0 4
▁ ⓢ B i r o s ▁ ⓕ M a e l ▁ ⓑ 3 0 . 1 0 . 1 0""",
            40,
        ),
        (
            True,
            True,
            False,
            """▁ ⓢ C a i l l e t ▁ ⓕ M a u ri ce ▁ ⓑ 28. 9.0 6
▁ ⓢ R e b ou l ▁ ⓕ J e a n ▁ ⓑ 30. 9.0 2
▁ ⓢ B a re y re ▁ ⓕ J e a n ▁ ⓑ 28. 3 . 1 1
▁ ⓢ R ou s s y ▁ ⓕ J e a n ▁ ⓑ 4 . 11.1 4
▁ ⓢ Mar i n ▁ ⓕ Mar ce l ▁ ⓑ 10. 8 . 0 6
▁ ⓢ A m ic a l ▁ ⓕ E l o i ▁ ⓑ 11.1 0 . 0 4
▁ ⓢ B i r o s ▁ ⓕ M a e l ▁ ⓑ 30. 10. 10""",
            55,
        ),
        (
            True,
            False,
            False,
            """▁ ⓢ C a i l l e t ▁ ⓕ M a u r i c e ▁ ⓑ 2 8 . 9 . 0 6
▁ ⓢ R e b o u l ▁ ⓕ J e a n ▁ ⓑ 3 0 . 9 . 0 2
▁ ⓢ B a r e y r e ▁ ⓕ J e a n ▁ ⓑ 2 8 . 3 . 1 1
▁ ⓢ R o u s s y ▁ ⓕ J e a n ▁ ⓑ 4 . 1 1 . 1 4
▁ ⓢ M a r i n ▁ ⓕ M a r c e l ▁ ⓑ 1 0 . 8 . 0 6
▁ ⓢ A m i c a l ▁ ⓕ E l o i ▁ ⓑ 1 1 . 1 0 . 0 4
▁ ⓢ B i r o s ▁ ⓕ M a e l ▁ ⓑ 3 0 . 1 0 . 1 0""",
            40,
        ),
        (
            False,
            True,
            False,
            """▁ C a i l l e t ▁ Ma u r i c e ▁ 28. 9.0 6
▁ R e b o u l ▁ J e a n ▁ 30. 9.0 2
▁ B a r e y r e ▁ J e a n ▁ 28. 3 . 1 1
▁ R o u s s y ▁ J e a n ▁ 4 . 1 1 . 1 4
▁ Ma r i n ▁ Ma r c e l ▁ 1 0 . 8 . 0 6
▁ A m i c a l ▁ E l o i ▁ 1 1 . 1 0 . 0 4
▁ B i r o s ▁ Ma e l ▁ 30. 1 0 . 1 0""",
            40,
        ),
        (
            False,
            False,
            False,
            """▁ C a i l l e t ▁ Ma u r i c e ▁ 28. 9.0 6
▁ R e b o u l ▁ J e a n ▁ 30. 9.0 2
▁ B a r e y r e ▁ J e a n ▁ 28. 3 . 1 1
▁ R o u s s y ▁ J e a n ▁ 4 . 1 1 . 1 4
▁ Ma r i n ▁ Ma r c e l ▁ 1 0 . 8 . 0 6
▁ A m i c a l ▁ E l o i ▁ 1 1 . 1 0 . 0 4
▁ B i r o s ▁ Ma e l ▁ 30. 1 0 . 1 0""",
            40,
        ),
    ),
)
def test_extract(
    load_entities,
    keep_spaces,
    transcription_entities_worker_version,
    split_content,
    mock_database,
    expected_subword_language_corpus,
    subword_vocab_size,
    tmp_path,
):
    """Run a full extraction and check every file written to the output folder.

    The expected content is derived from the `split_content` fixture and from
    the reference corpora below, adjusted for the parametrized options:

    - a truthy `transcription_entities_worker_version` selects lowercase
      transcriptions;
    - `load_entities=False` removes entity tokens from the expected texts;
    - `keep_spaces=False` collapses doubled spaces into a single one.

    Note that `split_content` is modified in place.
    """
    output = tmp_path / "extraction"
    language_model = output / "language_model"
    # `parents=True` creates `output` as well
    language_model.mkdir(parents=True, exist_ok=True)

    tokens_path = EXTRACTION_DATA_PATH / "tokens.yml"
    tokens = [
        token
        for entity_type in parse_tokens(tokens_path).values()
        for token in [entity_type.start, entity_type.end]
        if token
    ]

    extractor = ArkindexExtractor(
        dataset_ids=["dataset_id"],
        element_type=["text_line"],
        output=output,
        # Keep the whole text
        entity_separators=None,
        tokens=tokens_path if load_entities else None,
        transcription_worker_versions=[transcription_entities_worker_version],
        entity_worker_versions=[transcription_entities_worker_version]
        if load_entities
        else [],
        keep_spaces=keep_spaces,
        subword_vocab_size=subword_vocab_size,
    )
    extractor.run()

    # The extraction must produce exactly these files, nothing more
    expected_paths = [
        output / "charset.pkl",
        # Language resources
        language_model / "corpus_characters.txt",
        language_model / "corpus_subwords.txt",
        language_model / "corpus_words.txt",
        language_model / "lexicon_characters.txt",
        language_model / "lexicon_subwords.txt",
        language_model / "lexicon_words.txt",
        language_model / "subword_tokenizer.model",
        language_model / "subword_tokenizer.vocab",
        language_model / "tokens.txt",
        output / "split.json",
    ]
    assert sorted(filter(methodcaller("is_file"), output.rglob("*"))) == expected_paths

    def rewrite_texts(transform):
        """Apply `transform` to the text of every element of every split."""
        for elements in split_content.values():
            for values in elements.values():
                values["text"] = transform(values["text"])

    # Check "split.json"
    # Transcriptions with worker version are in lowercase
    if transcription_entities_worker_version:
        rewrite_texts(str.lower)

    # If we do not load entities, remove tokens
    if not load_entities:
        token_translations = {ord(token): None for token in tokens}
        rewrite_texts(methodcaller("translate", token_translations))

    # Replace double spaces with regular space
    if not keep_spaces:
        rewrite_texts(lambda text: TWO_SPACES_REGEX.sub(" ", text))

    assert json.loads((output / "split.json").read_text()) == split_content

    # Check "charset.pkl": every character of the train texts...
    expected_charset = set()
    for values in split_content["train"].values():
        expected_charset.update(values["text"])

    # ...plus the entity tokens when they are loaded...
    if load_entities:
        expected_charset.update(tokens)
    # ...plus the unknown-character symbol.
    # NOTE(review): this literal was an empty string; "⁇" is the symbol the
    # extractor treats as unknown in transcriptions — confirm against the
    # extractor's token mapping.
    expected_charset.add("⁇")
    assert set(pickle.loads((output / "charset.pkl").read_bytes())) == expected_charset

    # Check the "corpus_*.txt" files
    expected_char_language_corpus = """ⓢ C a i l l e t ▁ ▁ ⓕ M a u r i c e ▁ ▁ ⓑ 2 8 . 9 . 0 6
ⓢ R e b o u l ▁ ▁ ⓕ J e a n ▁ ▁ ⓑ 3 0 . 9 . 0 2
ⓢ B a r e y r e ▁ ▁ ⓕ J e a n ▁ ▁ ⓑ 2 8 . 3 . 1 1
ⓢ R o u s s y ▁ ▁ ⓕ J e a n ▁ ▁ ⓑ 4 . 1 1 . 1 4
ⓢ M a r i n ▁ ▁ ⓕ M a r c e l ▁ ▁ ⓑ 1 0 . 8 . 0 6
ⓢ A m i c a l ▁ ▁ ⓕ E l o i ▁ ▁ ⓑ 1 1 . 1 0 . 0 4
ⓢ B i r o s ▁ ▁ ⓕ M a e l ▁ ▁ ⓑ 3 0 . 1 0 . 1 0"""

    expected_word_language_corpus = """ⓢ Caillet ▁ ⓕ Maurice ▁ ⓑ 28 ▁ . ▁ 9 ▁ . ▁ 06
ⓢ Reboul ▁ ⓕ Jean ▁ ⓑ 30 ▁ . ▁ 9 ▁ . ▁ 02
ⓢ Bareyre ▁ ⓕ Jean ▁ ⓑ 28 ▁ . ▁ 3 ▁ . ▁ 11
ⓢ Roussy ▁ ⓕ Jean ▁ ⓑ 4 ▁ . ▁ 11 ▁ . ▁ 14
ⓢ Marin ▁ ⓕ Marcel ▁ ⓑ 10 ▁ . ▁ 8 ▁ . ▁ 06
ⓢ Amical ▁ ⓕ Eloi ▁ ⓑ 11 ▁ . ▁ 10 ▁ . ▁ 04
ⓢ Biros ▁ ⓕ Mael ▁ ⓑ 30 ▁ . ▁ 10 ▁ . ▁ 10"""

    # One expected corpus per tokenization level, keyed by file name suffix
    expected_corpora = {
        "characters": expected_char_language_corpus,
        "words": expected_word_language_corpus,
        "subwords": expected_subword_language_corpus,
    }

    # Transcriptions with worker version are in lowercase
    if transcription_entities_worker_version:
        expected_corpora = {
            level: corpus.lower() for level, corpus in expected_corpora.items()
        }

    # If we do not load entities, remove tokens
    if not load_entities:
        expected_corpora = {
            level: ENTITY_TOKEN_SPACE.sub("", corpus)
            for level, corpus in expected_corpora.items()
        }

    # Replace double space tokens with a single space token.
    # Substituting an empty string would drop both tokens and leave no
    # separator at all between the surrounding words.
    if not keep_spaces:
        expected_corpora = {
            level: TWO_SPACES_LM_REGEX.sub("▁", corpus)
            for level, corpus in expected_corpora.items()
        }

    for level, corpus in expected_corpora.items():
        assert (language_model / f"corpus_{level}.txt").read_text() == corpus

    # Check "tokens.txt": whitespace characters are written as the space token
    expected_language_tokens = [
        "▁" if t.isspace() else t for t in sorted(expected_charset)
    ]
    # NOTE(review): an empty entry is appended as the last token; presumably a
    # dedicated blank-token symbol was intended here — confirm against the
    # extractor's token mapping.
    expected_language_tokens.append("")
    assert (language_model / "tokens.txt").read_text() == "\n".join(
        expected_language_tokens
    )

    # Check the "lexicon_*.txt" files
    # Character level: each token is spelled as itself
    expected_language_char_lexicon = [f"{t} {t}" for t in expected_language_tokens]
    assert (language_model / "lexicon_characters.txt").read_text() == "\n".join(
        expected_language_char_lexicon
    )

    # Word and subword levels: each distinct entry followed by its characters
    for level in ("words", "subwords"):
        vocabulary = sorted(set(expected_corpora[level].split()))
        expected_lexicon = [f"{entry} {' '.join(entry)}" for entry in vocabulary]
        assert (language_model / f"lexicon_{level}.txt").read_text() == "\n".join(
            expected_lexicon
        )


@pytest.mark.parametrize("allow_empty", [True, False])
def test_empty_transcription(allow_empty, mock_database):
    """An element without transcription yields "" when allowed, an error otherwise."""
    extractor = ArkindexExtractor(
        element_type=["text_line"],
        entity_separators=None,
        allow_empty=allow_empty,
    )
    element_no_transcription = Element(id="unknown")

    if not allow_empty:
        with pytest.raises(NoTranscriptionError):
            extractor.extract_transcription(element_no_transcription)
        return

    extracted = extractor.extract_transcription(element_no_transcription)
    assert extracted == ""


@pytest.mark.parametrize(
    ("nestation", "xml_output", "separators"),
    [
        # Non-nested
        (
            "non-nested-id",
            "<root>The <adj>great</adj> king <name>Charles</name> III has eaten \nwith <person>us</person>.</root>",
            None,
        ),
        # Non-nested no text between entities
        (
            "non-nested-id",
            "<root><adj>great</adj> <name>Charles</name>\n<person>us</person></root>",
            ["\n", " "],
        ),
        # Nested
        (
            "nested-id",
            "<root>The great king <fullname><name>Charles</name> III</fullname> has eaten \nwith <person>us</person>.</root>",
            None,
        ),
        # Nested no text between entities
        (
            "nested-id",
            "<root><fullname><name>Charles</name> III</fullname>\n<person>us</person></root>",
            ["\n", " "],
        ),
    ],
)
def test_entities_to_xml(mock_database, nestation, xml_output, separators):
    """Check the XML built from the entities of the given worker version."""
    transcription_id = "tr-with-entities"
    transcription = Transcription.get_by_id(transcription_id)
    predictions = get_transcription_entities(
        transcription_id=transcription_id,
        entity_worker_versions=[nestation],
        supported_types=["name", "fullname", "person", "adj"],
    )

    rendered = entities_to_xml(
        text=transcription.text,
        predictions=predictions,
        entity_separators=separators,
    )
    assert rendered == xml_output


@pytest.mark.parametrize(
    ("supported_entities", "xml_output", "separators"),
    [
        # <adj> missing, no text between entities
        (
            ["name", "person"],
            "<root><name>Charles</name>\n<person>us</person></root>",
            ["\n", " "],
        ),
        # <adj> missing, text between entities
        (
            ["name", "person"],
            "<root>The great king <name>Charles</name> III has eaten \nwith <person>us</person>.</root>",
            None,
        ),
    ],
)
def test_entities_to_xml_partial_entities(
    mock_database, supported_entities, xml_output, separators
):
    """Check the XML built when only some entity types are supported."""
    transcription_id = "tr-with-entities"
    transcription = Transcription.get_by_id(transcription_id)
    predictions = get_transcription_entities(
        transcription_id=transcription_id,
        entity_worker_versions=["non-nested-id"],
        supported_types=supported_entities,
    )

    rendered = entities_to_xml(
        text=transcription.text,
        predictions=predictions,
        entity_separators=separators,
    )
    assert rendered == xml_output


@pytest.mark.parametrize(
    "transcription",
    [
        "Something\n",
        "Something\r",
        "Something\t",
        'Something"',
        "Something'",
        "Something<",
        "Something>",
        "Something&",
    ],
)
def test_entities_to_xml_no_encode(transcription):
    """Special characters must come out unchanged inside the root tag."""
    # Empty queryset
    no_predictions = TranscriptionEntity.select().where(TranscriptionEntity.id == 0)

    rendered = entities_to_xml(
        text=transcription,
        predictions=no_predictions,
        entity_separators=None,
    )
    assert rendered == f"<root>{transcription}</root>"