Skip to content
Snippets Groups Projects
test_get_labels_aligned.py 2.67 KiB
import pytest

from nerval import evaluate

fake_annot_original = "Gérard de Nerval was born in Paris in 1808 ."
fake_predict_original = "G*rard de *N*erval bo*rn in Paris in 1833 *."

fake_annot_aligned = "Gérard de -N-erval was bo-rn in Paris in 1808 -."
fake_predict_aligned = "G*rard de *N*erval ----bo*rn in Paris in 1833 *."

# fmt: off
fake_annot_tags_original = [
    "B-PER", "I-PER", "I-PER", "I-PER", "I-PER", "I-PER",
    "I-PER",
    "I-PER", "I-PER",
    "I-PER",
    "I-PER", "I-PER", "I-PER", "I-PER", "I-PER", "I-PER",
    "O",
    "O", "O", "O",
    "O",
    "O", "O", "O", "O",
    "O",
    "O", "O",
    "O",
    "B-LOC", "I-LOC", "I-LOC", "I-LOC", "I-LOC",
    "O",
    "O", "O",
    "O",
    "B-DAT", "I-DAT", "I-DAT", "I-DAT",
    "O",
    "O",
]

fake_predict_tags_original = [
    "B-PER", "I-PER", "I-PER", "I-PER", "I-PER", "I-PER",
    "I-PER",
    "I-PER", "I-PER",
    "I-PER",
    "I-PER", "I-PER", "I-PER", "I-PER", "I-PER", "I-PER", "I-PER", "I-PER",
    "O",
    "O", "O", "O", "O", "O",
    "O",
    "O", "O",
    "O",
    "***", "***", "***", "***", "***",
    "O",
    "O", "O",
    "O",
    "B-DAT", "I-DAT", "I-DAT", "I-DAT",
    "O",
    "O", "O",
]

expected_annot_tags_aligned = [
    "B-PER", "I-PER", "I-PER", "I-PER", "I-PER", "I-PER",
    "I-PER",
    "I-PER", "I-PER",
    "I-PER",
    "I-PER", "I-PER", "I-PER", "I-PER", "I-PER", "I-PER", "I-PER", "I-PER",
    "O",
    "O", "O", "O",
    "O",
    "O", "O", "O", "O", "O",
    "O",
    "O", "O",
    "O",
    "B-LOC", "I-LOC", "I-LOC", "I-LOC", "I-LOC",
    "O",
    "O", "O",
    "O",
    "B-DAT", "I-DAT", "I-DAT", "I-DAT",
    "O",
    "O", "O",
]

expected_predict_tags_aligned = [
    "B-PER", "I-PER", "I-PER", "I-PER", "I-PER", "I-PER",
    "I-PER",
    "I-PER", "I-PER",
    "I-PER",
    "I-PER", "I-PER", "I-PER", "I-PER", "I-PER", "I-PER", "I-PER", "I-PER",
    "O",
    "O", "O", "O", "O",
    "O", "O", "O", "O", "O",
    "O",
    "O", "O",
    "O",
    "***", "***", "***", "***", "***",
    "O",
    "O", "O",
    "O",
    "B-DAT", "I-DAT", "I-DAT", "I-DAT",
    "O",
    "O", "O",
]
# fmt: on


@pytest.mark.parametrize(
    ("test_input", "expected"),
    [
        (
            (fake_annot_original, fake_annot_aligned, fake_annot_tags_original),
            expected_annot_tags_aligned,
        ),
        (
            (fake_predict_original, fake_predict_aligned, fake_predict_tags_original),
            expected_predict_tags_aligned,
        ),
    ],
)
def test_get_labels_aligned(test_input, expected):
    assert evaluate.get_labels_aligned(*test_input) == expected


def test_get_labels_aligned_empty_entry():
    with pytest.raises(AssertionError):
        evaluate.get_labels_aligned(None, None, None)