# -*- coding: utf-8 -*-
import json
import pickle
import re
from operator import methodcaller
from typing import NamedTuple
import pytest
from arkindex_export import (
DatasetElement,
Element,
Transcription,
TranscriptionEntity,
)
from dan.datasets.extract.arkindex import ArkindexExtractor
from dan.datasets.extract.db import get_transcription_entities
from dan.datasets.extract.exceptions import (
NoTranscriptionError,
UnknownTokenInText,
)
from dan.datasets.extract.utils import (
EntityType,
entities_to_xml,
normalize_linebreaks,
normalize_spaces,
)
from dan.utils import parse_tokens
from tests import FIXTURES

EXTRACTION_DATA_PATH = FIXTURES / "extraction"

# Two consecutive spaces
TWO_SPACES_REGEX = re.compile(r" {2}")
# An entity token (ⓢ, ⓕ or ⓑ) followed by a space
ENTITY_TOKEN_SPACE = re.compile(r"[ⓢⓕⓑ] ")
# Two consecutive "▁" space markers in the language model corpora
TWO_SPACES_LM_REGEX = re.compile(r"▁ ▁")
# NamedTuple to mock actual database result
Entity = NamedTuple("Entity", offset=int, length=int, type=str, value=str)
TOKENS = {
"P": EntityType(start="ⓟ", end="Ⓟ"),
"D": EntityType(start="ⓓ", end="Ⓓ"),
"N": EntityType(start="ⓝ", end="Ⓝ"),
"I": EntityType(start="ⓘ", end="Ⓘ"),
}


def filter_tokens(keys):
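    """Keep only the TOKENS entries whose key is in ``keys``."""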
    return {key: value for key, value in TOKENS.items() if key in keys}


@pytest.mark.parametrize(
"text,trimmed",
(
("no_spaces", "no_spaces"),
(" beginning", "beginning"),
("ending ", "ending"),
(" both ", "both"),
(" consecutive", "consecutive"),
("\ttab", "tab"),
("\t tab", "tab"),
(" \ttab", "tab"),
("no|space", "no|space"),
),
)
def test_normalize_spaces(text, trimmed):
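    """Leading, trailing and consecutive spaces and tabs should be normalized away."""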
    assert normalize_spaces(text) == trimmed


@pytest.mark.parametrize(
"text,trimmed",
(
("no_linebreaks", "no_linebreaks"),
("\nbeginning", "beginning"),
("ending\n", "ending"),
("\nboth\n", "both"),
("\n\n\nconsecutive", "consecutive"),
("\rcarriage_return", "carriage_return"),
("\r\ncarriage_return+linebreak", "carriage_return+linebreak"),
("\n\r\r\n\ncarriage_return+linebreak", "carriage_return+linebreak"),
("no|linebreaks", "no|linebreaks"),
),
)
def test_normalize_linebreaks(text, trimmed):
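    """Leading, trailing and consecutive line breaks and carriage returns should be normalized away."""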
    assert normalize_linebreaks(text) == trimmed


def test_process_element_unknown_token_in_text_error(mock_database, tmp_path):
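    """process_element should raise UnknownTokenInText when the transcription text contains the unknown token ("⁇")."""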
output = tmp_path / "extraction"
arkindex_extractor = ArkindexExtractor(output=output)
    # Retrieve a dataset element and set the transcription text to one containing the unknown token
dataset_element = DatasetElement.select().first()
element = dataset_element.element
Transcription.update({Transcription.text: "Is this text valid⁇"}).execute()
with pytest.raises(
UnknownTokenInText,
match=re.escape(
f"Unknown token found in the transcription text of element ({element.id})"
),
):
        arkindex_extractor.process_element(dataset_element, element)


@pytest.mark.parametrize(
"load_entities,keep_spaces,transcription_entities_worker_version,expected_subword_language_corpus,subword_vocab_size",
(
(
True,
True,
"worker_version_id",
"""▁ ⓢ c a i l l e t ▁ ⓕ m a u r i c e ▁ ⓑ 28. 9.0 6
▁ ⓢ re b ou l ▁ ⓕ j e a n ▁ ⓑ 30. 9.0 2
▁ ⓢ b a re y re ▁ ⓕ j e a n ▁ ⓑ 28. 3 . 1 1
▁ ⓢ r ou s s y ▁ ⓕ j e a n ▁ ⓑ 4 . 1 1 . 1 4
▁ ⓢ m a r i n ▁ ⓕ m a r c e l ▁ ⓑ 1 0 . 8 . 0 6
▁ ⓢ a m i c a l ▁ ⓕ e l o i ▁ ⓑ 1 1 . 1 0 . 0 4
▁ ⓢ b i r o s ▁ ⓕ m a e l ▁ ⓑ 30. 1 0 . 1 0""",
40,
),
(
True,
False,
"worker_version_id",
"""▁ ⓢ c a i l l e t ▁ ⓕ m a u r i c e ▁ ⓑ 28. 9.0 6
▁ ⓢ re b ou l ▁ ⓕ j e a n ▁ ⓑ 30. 9.0 2
▁ ⓢ b a re y re ▁ ⓕ j e a n ▁ ⓑ 28. 3 . 1 1
▁ ⓢ r ou s s y ▁ ⓕ j e a n ▁ ⓑ 4 . 1 1 . 1 4
▁ ⓢ m a r i n ▁ ⓕ m a r c e l ▁ ⓑ 1 0 . 8 . 0 6
▁ ⓢ a m i c a l ▁ ⓕ e l o i ▁ ⓑ 1 1 . 1 0 . 0 4
▁ ⓢ b i r o s ▁ ⓕ m a e l ▁ ⓑ 30. 1 0 . 1 0""",
40,
),
(
False,
True,
"worker_version_id",
"""▁ ca i l l e t ▁ ma u r i ce ▁ 28. 9.0 6
▁ re b o u l ▁ j e a n ▁ 30. 9.0 2
▁ b a re y re ▁ j e a n ▁ 28. 3 . 1 1
▁ r o u s s y ▁ j e a n ▁ 4 . 11.1 4
▁ ma r i n ▁ ma r ce l ▁ 10. 8 . 0 6
▁ a m i ca l ▁ el o i ▁ 11.1 0 . 0 4
▁ b i r o s ▁ ma el ▁ 30. 10. 1 0""",
40,
),
(
False,
False,
"worker_version_id",
"""▁ ca i l l e t ▁ ma u r i ce ▁ 28. 9.0 6
▁ re b o u l ▁ j e a n ▁ 30. 9.0 2
▁ b a re y re ▁ j e a n ▁ 28. 3 . 1 1
▁ r o u s s y ▁ j e a n ▁ 4 . 11.1 4
▁ ma r i n ▁ ma r ce l ▁ 10. 8 . 0 6
▁ a m i ca l ▁ el o i ▁ 11.1 0 . 0 4
▁ b i r o s ▁ ma el ▁ 30. 10. 1 0""",
40,
),
(
True,
True,
False,
"""▁ ⓢ C a i l l e t ▁ ⓕ M a u r i c e ▁ ⓑ 2 8 . 9 . 0 6
▁ ⓢ R e b o u l ▁ ⓕ J e a n ▁ ⓑ 3 0 . 9 . 0 2
▁ ⓢ B a r e y r e ▁ ⓕ J e a n ▁ ⓑ 2 8 . 3 . 1 1
▁ ⓢ R o u s s y ▁ ⓕ J e a n ▁ ⓑ 4 . 1 1 . 1 4
▁ ⓢ M a r i n ▁ ⓕ M a r c e l ▁ ⓑ 1 0 . 8 . 0 6
▁ ⓢ A m i c a l ▁ ⓕ E l o i ▁ ⓑ 1 1 . 1 0 . 0 4
▁ ⓢ B i r o s ▁ ⓕ M a e l ▁ ⓑ 3 0 . 1 0 . 1 0""",
40,
),
(
True,
True,
False,
"""▁ ⓢ C a i l l e t ▁ ⓕ M a u ri ce ▁ ⓑ 28. 9.0 6
▁ ⓢ R e b ou l ▁ ⓕ J e a n ▁ ⓑ 30. 9.0 2
▁ ⓢ B a re y re ▁ ⓕ J e a n ▁ ⓑ 28. 3 . 1 1
▁ ⓢ R ou s s y ▁ ⓕ J e a n ▁ ⓑ 4 . 11.1 4
▁ ⓢ Mar i n ▁ ⓕ Mar ce l ▁ ⓑ 10. 8 . 0 6
▁ ⓢ A m ic a l ▁ ⓕ E l o i ▁ ⓑ 11.1 0 . 0 4
▁ ⓢ B i r o s ▁ ⓕ M a e l ▁ ⓑ 30. 10. 10""",
55,
),
(
True,
False,
False,
"""▁ ⓢ C a i l l e t ▁ ⓕ M a u r i c e ▁ ⓑ 2 8 . 9 . 0 6
▁ ⓢ R e b o u l ▁ ⓕ J e a n ▁ ⓑ 3 0 . 9 . 0 2
▁ ⓢ B a r e y r e ▁ ⓕ J e a n ▁ ⓑ 2 8 . 3 . 1 1
▁ ⓢ R o u s s y ▁ ⓕ J e a n ▁ ⓑ 4 . 1 1 . 1 4
▁ ⓢ M a r i n ▁ ⓕ M a r c e l ▁ ⓑ 1 0 . 8 . 0 6
▁ ⓢ A m i c a l ▁ ⓕ E l o i ▁ ⓑ 1 1 . 1 0 . 0 4
▁ ⓢ B i r o s ▁ ⓕ M a e l ▁ ⓑ 3 0 . 1 0 . 1 0""",
40,
),
(
False,
True,
False,
"""▁ C a i l l e t ▁ Ma u r i c e ▁ 28. 9.0 6
▁ R e b o u l ▁ J e a n ▁ 30. 9.0 2
▁ B a r e y r e ▁ J e a n ▁ 28. 3 . 1 1
▁ R o u s s y ▁ J e a n ▁ 4 . 1 1 . 1 4
▁ Ma r i n ▁ Ma r c e l ▁ 1 0 . 8 . 0 6
▁ A m i c a l ▁ E l o i ▁ 1 1 . 1 0 . 0 4
▁ B i r o s ▁ Ma e l ▁ 30. 1 0 . 1 0""",
40,
),
(
False,
False,
False,
"""▁ C a i l l e t ▁ Ma u r i c e ▁ 28. 9.0 6
▁ R e b o u l ▁ J e a n ▁ 30. 9.0 2
▁ B a r e y r e ▁ J e a n ▁ 28. 3 . 1 1
▁ R o u s s y ▁ J e a n ▁ 4 . 1 1 . 1 4
▁ Ma r i n ▁ Ma r c e l ▁ 1 0 . 8 . 0 6
▁ A m i c a l ▁ E l o i ▁ 1 1 . 1 0 . 0 4
▁ B i r o s ▁ Ma e l ▁ 30. 1 0 . 1 0""",
40,
),
),
)
def test_extract(
load_entities,
keep_spaces,
transcription_entities_worker_version,
split_content,
mock_database,
expected_subword_language_corpus,
subword_vocab_size,
tmp_path,
):
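    """End-to-end extraction: check the produced files, the split content, the charset and the language model corpora, lexicons and tokens."""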
output = tmp_path / "extraction"
output.mkdir(parents=True, exist_ok=True)
(output / "language_model").mkdir(parents=True, exist_ok=True)
tokens_path = EXTRACTION_DATA_PATH / "tokens.yml"
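    # Collect the start/end tokens of every entity type defined in tokens.yml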
tokens = [
token
for entity_type in parse_tokens(tokens_path).values()
for token in [entity_type.start, entity_type.end]
if token
]
extractor = ArkindexExtractor(
dataset_ids=["dataset_id"],
element_type=["text_line"],
output=output,
# Keep the whole text
entity_separators=None,
tokens=tokens_path if load_entities else None,
transcription_worker_versions=[transcription_entities_worker_version],
entity_worker_versions=[transcription_entities_worker_version]
if load_entities
else [],
keep_spaces=keep_spaces,
subword_vocab_size=subword_vocab_size,
)
extractor.run()
expected_paths = [
output / "charset.pkl",
# Language resources
output / "language_model" / "corpus_characters.txt",
output / "language_model" / "corpus_subwords.txt",
output / "language_model" / "corpus_words.txt",
output / "language_model" / "lexicon_characters.txt",
output / "language_model" / "lexicon_subwords.txt",
output / "language_model" / "lexicon_words.txt",
output / "language_model" / "subword_tokenizer.model",
output / "language_model" / "subword_tokenizer.vocab",
output / "language_model" / "tokens.txt",
output / "split.json",
]
assert sorted(filter(methodcaller("is_file"), output.rglob("*"))) == expected_paths
# Check "split.json"
    # Transcriptions with a worker version are in lowercase
if transcription_entities_worker_version:
for split in split_content:
for element_id in split_content[split]:
split_content[split][element_id]["text"] = split_content[split][
element_id
]["text"].lower()
# If we do not load entities, remove tokens
if not load_entities:
token_translations = {ord(token): None for token in tokens}
for split in split_content:
for element_id in split_content[split]:
split_content[split][element_id]["text"] = split_content[split][
element_id
]["text"].translate(token_translations)
# Replace double spaces with regular space
if not keep_spaces:
for split in split_content:
for element_id in split_content[split]:
split_content[split][element_id]["text"] = TWO_SPACES_REGEX.sub(
" ", split_content[split][element_id]["text"]
)
assert json.loads((output / "split.json").read_text()) == split_content
# Check "charset.pkl"
expected_charset = set()
for values in split_content["train"].values():
expected_charset.update(set(values["text"]))
if load_entities:
expected_charset.update(tokens)
expected_charset.add("⁇")
assert set(pickle.loads((output / "charset.pkl").read_bytes())) == expected_charset
# Check "language_corpus.txt"
expected_char_language_corpus = """ⓢ C a i l l e t ▁ ▁ ⓕ M a u r i c e ▁ ▁ ⓑ 2 8 . 9 . 0 6
ⓢ R e b o u l ▁ ▁ ⓕ J e a n ▁ ▁ ⓑ 3 0 . 9 . 0 2
ⓢ B a r e y r e ▁ ▁ ⓕ J e a n ▁ ▁ ⓑ 2 8 . 3 . 1 1
ⓢ R o u s s y ▁ ▁ ⓕ J e a n ▁ ▁ ⓑ 4 . 1 1 . 1 4
ⓢ M a r i n ▁ ▁ ⓕ M a r c e l ▁ ▁ ⓑ 1 0 . 8 . 0 6
ⓢ A m i c a l ▁ ▁ ⓕ E l o i ▁ ▁ ⓑ 1 1 . 1 0 . 0 4
ⓢ B i r o s ▁ ▁ ⓕ M a e l ▁ ▁ ⓑ 3 0 . 1 0 . 1 0"""
expected_word_language_corpus = """ⓢ Caillet ▁ ⓕ Maurice ▁ ⓑ 28 ▁ . ▁ 9 ▁ . ▁ 06
ⓢ Reboul ▁ ⓕ Jean ▁ ⓑ 30 ▁ . ▁ 9 ▁ . ▁ 02
ⓢ Bareyre ▁ ⓕ Jean ▁ ⓑ 28 ▁ . ▁ 3 ▁ . ▁ 11
ⓢ Roussy ▁ ⓕ Jean ▁ ⓑ 4 ▁ . ▁ 11 ▁ . ▁ 14
ⓢ Marin ▁ ⓕ Marcel ▁ ⓑ 10 ▁ . ▁ 8 ▁ . ▁ 06
ⓢ Amical ▁ ⓕ Eloi ▁ ⓑ 11 ▁ . ▁ 10 ▁ . ▁ 04
ⓢ Biros ▁ ⓕ Mael ▁ ⓑ 30 ▁ . ▁ 10 ▁ . ▁ 10"""
    # Transcriptions with a worker version are in lowercase
if transcription_entities_worker_version:
expected_char_language_corpus = expected_char_language_corpus.lower()
expected_word_language_corpus = expected_word_language_corpus.lower()
expected_subword_language_corpus = expected_subword_language_corpus.lower()
# If we do not load entities, remove tokens
if not load_entities:
expected_char_language_corpus = ENTITY_TOKEN_SPACE.sub(
"", expected_char_language_corpus
)
expected_word_language_corpus = ENTITY_TOKEN_SPACE.sub(
"", expected_word_language_corpus
)
expected_subword_language_corpus = ENTITY_TOKEN_SPACE.sub(
"", expected_subword_language_corpus
)
# Replace double spaces with regular space
if not keep_spaces:
expected_char_language_corpus = TWO_SPACES_LM_REGEX.sub(
"▁", expected_char_language_corpus
)
expected_word_language_corpus = TWO_SPACES_LM_REGEX.sub(
"▁", expected_word_language_corpus
)
expected_subword_language_corpus = TWO_SPACES_LM_REGEX.sub(
"▁", expected_subword_language_corpus
)
assert (
output / "language_model" / "corpus_characters.txt"
).read_text() == expected_char_language_corpus
assert (
output / "language_model" / "corpus_words.txt"
).read_text() == expected_word_language_corpus
assert (
output / "language_model" / "corpus_subwords.txt"
).read_text() == expected_subword_language_corpus
# Check "language_tokens.txt"
expected_language_tokens = [
"▁" if t.isspace() else t for t in sorted(list(expected_charset))
]
expected_language_tokens.append("◌")
assert (output / "language_model" / "tokens.txt").read_text() == "\n".join(
expected_language_tokens
)
# Check "language_lexicon.txt"
expected_language_char_lexicon = [f"{t} {t}" for t in expected_language_tokens]
assert (
output / "language_model" / "lexicon_characters.txt"
).read_text() == "\n".join(expected_language_char_lexicon)
    word_vocab = set(expected_word_language_corpus.split())
expected_language_word_lexicon = [
f"{word} {' '.join(word)}" for word in sorted(word_vocab)
]
assert (output / "language_model" / "lexicon_words.txt").read_text() == "\n".join(
expected_language_word_lexicon
)
    subword_vocab = set(expected_subword_language_corpus.split())
expected_language_subword_lexicon = [
f"{subword} {' '.join(subword)}" for subword in sorted(subword_vocab)
]
assert (
output / "language_model" / "lexicon_subwords.txt"
    ).read_text() == "\n".join(expected_language_subword_lexicon)


@pytest.mark.parametrize("allow_empty", (True, False))
def test_empty_transcription(allow_empty, mock_database):
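    """Extracting an element without any transcription should return an empty string when ``allow_empty`` is set, and raise NoTranscriptionError otherwise."""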
extractor = ArkindexExtractor(
element_type=["text_line"],
entity_separators=None,
allow_empty=allow_empty,
)
element_no_transcription = Element(id="unknown")
if allow_empty:
assert extractor.extract_transcription(element_no_transcription) == ""
else:
with pytest.raises(NoTranscriptionError):
            extractor.extract_transcription(element_no_transcription)


@pytest.mark.parametrize(
"nestation, xml_output, separators",
(
# Non-nested
(
"non-nested-id",
"<root>The <adj>great</adj> king <name>Charles</name> III has eaten \nwith <person>us</person>.</root>",
None,
),
# Non-nested no text between entities
(
"non-nested-id",
"<root><adj>great</adj> <name>Charles</name>\n<person>us</person></root>",
["\n", " "],
),
# Nested
(
"nested-id",
"<root>The great king <fullname><name>Charles</name> III</fullname> has eaten \nwith <person>us</person>.</root>",
None,
),
# Nested no text between entities
(
"nested-id",
"<root><fullname><name>Charles</name> III</fullname>\n<person>us</person></root>",
["\n", " "],
),
),
)
def test_entities_to_xml(mock_database, nestation, xml_output, separators):
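    """Entities should be wrapped in XML tags named after their type; nested entities are supported, and when ``entity_separators`` is set only separators are kept between entities."""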
transcription = Transcription.get_by_id("tr-with-entities")
assert (
entities_to_xml(
text=transcription.text,
predictions=get_transcription_entities(
transcription_id="tr-with-entities",
entity_worker_versions=[nestation],
supported_types=["name", "fullname", "person", "adj"],
),
entity_separators=separators,
)
== xml_output
    )


@pytest.mark.parametrize(
"supported_entities, xml_output, separators",
(
# <adj> missing, no text between entities
(
["name", "person"],
"<root><name>Charles</name>\n<person>us</person></root>",
["\n", " "],
),
# <adj> missing, text between entities
(
["name", "person"],
"<root>The great king <name>Charles</name> III has eaten \nwith <person>us</person>.</root>",
None,
),
),
)
def test_entities_to_xml_partial_entities(
mock_database, supported_entities, xml_output, separators
):
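    """Entity types missing from ``supported_types`` should not be tagged in the XML output."""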
transcription = Transcription.get_by_id("tr-with-entities")
assert (
entities_to_xml(
text=transcription.text,
predictions=get_transcription_entities(
transcription_id="tr-with-entities",
entity_worker_versions=["non-nested-id"],
supported_types=supported_entities,
),
entity_separators=separators,
)
== xml_output
    )


@pytest.mark.parametrize(
"transcription",
(
"Something\n",
"Something\r",
"Something\t",
'Something"',
"Something'",
"Something<",
"Something>",
"Something&",
),
)
def test_entities_to_xml_no_encode(transcription):
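    """Characters with a special meaning in XML (quotes, <, >, &) and control characters must be kept as-is, without escaping."""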
assert (
entities_to_xml(
text=transcription,
# Empty queryset
predictions=TranscriptionEntity.select().where(TranscriptionEntity.id == 0),
entity_separators=None,
)
== f"<root>{transcription}</root>"
)