# test_compute_scores.py
# -*- coding: utf-8 -*-
import pytest

from nerval import evaluate

# Entity counts per type in the annotation (the reference). "All" is the sum of
# the per-type counts (1 + 1 + 1 here).
fake_annot_entity_count = {
    "All": 3,
    "DAT": 1,
    "LOC": 1,
    "PER": 1,
}

# Entity counts per type in the prediction. "***" is a type that does not occur
# in the annotation, and "LOC" is never predicted.
fake_predict_entity_count = {
    "All": 3,
    "DAT": 1,
    "***": 1,
    "PER": 1,
}

# Number of predicted entities that matched an annotated one, per type
# ("All" is again the sum of the per-type values).
fake_matches = {
    "All": 1,
    "PER": 1,
    "LOC": 0,
    "DAT": 0,
}

# Scores that compute_scores is expected to return for the fixtures above,
# keyed by entity type:
#   P / R / F1  - precision (matched / predicted), recall (matched / Support)
#                 and their F1 score;
#   predicted   - the prediction count for that type;
#   matched     - the match count for that type (0 when the type has no entry);
#   Support     - the annotation count for that type.
# None marks a value that is undefined for the type: "***" never appears in the
# annotation (no recall, F1 or support), and "LOC" is never predicted (no
# precision, F1 or predicted count).
expected_scores = {
    "***": {
        "P": 0.0,
        "R": None,
        "F1": None,
        "predicted": 1,
        "matched": 0,
        "Support": None,
    },
    "DAT": {
        "P": 0.0,
        "R": 0.0,
        # Integer 0, not 0.0: kept exactly as in the original table.
        "F1": 0,
        "predicted": 1,
        "matched": 0,
        "Support": 1,
    },
    "All": {
        # 1 match out of 3 predicted and 3 annotated entities.
        "P": 1 / 3,
        "R": 1 / 3,
        "F1": 1 / 3,
        "predicted": 3,
        "matched": 1,
        "Support": 3,
    },
    "PER": {
        "P": 1.0,
        "R": 1.0,
        "F1": 1.0,
        "predicted": 1,
        "matched": 1,
        "Support": 1,
    },
    "LOC": {
        "P": None,
        "R": 0.0,
        "F1": None,
        "predicted": None,
        "matched": 0,
        "Support": 1,
    },
}


@pytest.mark.parametrize(
    "test_input, expected",
    [
        (
            (fake_annot_entity_count, fake_predict_entity_count, fake_matches),
            expected_scores,
        ),
    ],
)
def test_compute_scores(test_input, expected):
    """Check ``evaluate.compute_scores`` against a hand-computed score table.

    ``test_input`` is the (annotation counts, prediction counts, matches)
    triple passed positionally to ``compute_scores``; ``expected`` is the
    per-entity-type dictionary the result must compare equal to.
    """
    annot_counts, predict_counts, match_counts = test_input
    scores = evaluate.compute_scores(annot_counts, predict_counts, match_counts)
    assert scores == expected