Commit 9a2b1af8 authored by Blanche Miret

Add threshold as argument

parent b2db7c00
Merge request !6: Add threshold as option
Pipeline #103797 passed
@@ -9,7 +9,6 @@ import editdistance
 import edlib
 import termtables as tt

-THRESHOLD = 0.30

 NOT_ENTITY_TAG = "O"
@@ -189,7 +188,11 @@ def look_for_further_entity_part(index, tag, characters, labels):
 def compute_matches(
-    annotation: str, prediction: str, labels_annot: list, labels_predict: list
+    annotation: str,
+    prediction: str,
+    labels_annot: list,
+    labels_predict: list,
+    threshold: float,
 ) -> dict:
     """Compute prediction score from annotation string to prediction string.
@@ -324,7 +327,7 @@ def compute_matches(
             score = (
                 1
                 if editdistance.eval(entity_ref, entity_compar) / len_entity
-                < THRESHOLD
+                <= threshold
                 else 0
             )
             entity_count[last_tag] = entity_count.get(last_tag, 0) + score
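For concreteness, a minimal sketch of the comparison this hunk changes (the entity strings are invented). The edit distance between a gold entity and its predicted counterpart is normalized by the gold entity's length; switching from `<` to `<=` means a pair whose normalized distance lands exactly on the threshold now counts as a match.

```python
import editdistance

# Invented example pair: a gold entity vs. an OCR-damaged prediction.
entity_ref, entity_compar = "Nerval", "N*rval"
len_entity = len(entity_ref)

# editdistance.eval("Nerval", "N*rval") == 1, and 1 / 6 ≈ 0.17 <= 0.30, so it matches.
score = 1 if editdistance.eval(entity_ref, entity_compar) / len_entity <= 0.30 else 0
assert score == 1
```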
@@ -454,7 +457,7 @@ def print_results(scores: dict):
     tt.print(results, header, style=tt.styles.markdown)

-def run(annotation: str, prediction: str) -> dict:
+def run(annotation: str, prediction: str, threshold: float) -> dict:
     """Compute recall and precision for each entity type found in annotation and/or prediction.

     Each measure is given at document level, global score is a micro-average across entity types.
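A usage sketch of the new `run` signature (the BIO file names are placeholders): callers now pass the threshold explicitly instead of relying on the removed module-level `THRESHOLD` constant.

```python
from nerval import evaluate

# Placeholder paths; any annotation/prediction pair in BIO format works.
scores = evaluate.run("annot.bio", "predict.bio", 0.30)
```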
@@ -486,7 +489,11 @@ def run(annotation: str, prediction: str) -> dict:
     # Get nb match
     matches = compute_matches(
-        annot_aligned, predict_aligned, labels_annot_aligned, labels_predict_aligned
+        annot_aligned,
+        predict_aligned,
+        labels_annot_aligned,
+        labels_predict_aligned,
+        threshold,
     )

     # Compute scores
@@ -498,6 +505,17 @@ def run(annotation: str, prediction: str) -> dict:
     return scores

+def threshold_float_type(arg):
+    """Type function for argparse."""
+    try:
+        f = float(arg)
+    except ValueError:
+        raise argparse.ArgumentTypeError("Must be a floating point number.")
+    if f < 0 or f > 1:
+        raise argparse.ArgumentTypeError("Must be between 0 and 1.")
+    return f
+
 def main():
     """Get arguments and run."""
@@ -510,9 +528,17 @@ def main():
     parser.add_argument(
         "-p", "--predict", help="Prediction in BIO format.", required=True
     )
+    parser.add_argument(
+        "-t",
+        "--threshold",
+        help="Set a distance threshold for the match between gold and predicted entity.",
+        required=False,
+        default=0.3,
+        type=threshold_float_type,
+    )
     args = parser.parse_args()

-    run(args.annot, args.predict)
+    run(args.annot, args.predict, args.threshold)
 if __name__ == "__main__":
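With this change the matching tolerance becomes tunable from the command line, e.g. `python -m nerval.evaluate -a annot.bio -p predict.bio -t 0.25` (the invocation and file names are illustrative). Omitting `-t` keeps the old value of 0.3, although pairs that sit exactly on the threshold now count as matches because the comparison changed from `<` to `<=`, and `threshold_float_type` rejects values outside [0, 1].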
@@ -3,6 +3,8 @@ import pytest
 from nerval import evaluate

+THRESHOLD = 0.30
+
 fake_annot_aligned = "Gérard de -N-erval was bo-rn in Paris in 1808 -."
 fake_predict_aligned = "G*rard de *N*erval ----bo*rn in Paris in 1833 *."
@@ -153,6 +155,7 @@ expected_matches_nested_false = {"All": 2, "PER": 1, "LOC": 1}
                 fake_predict_aligned,
                 fake_annot_tags_aligned,
                 fake_predict_tags_aligned,
+                THRESHOLD,
             ),
             expected_matches,
         ),
@@ -162,6 +165,7 @@ expected_matches_nested_false = {"All": 2, "PER": 1, "LOC": 1}
                 fake_string_nested,
                 fake_tags_aligned_nested_perfect,
                 fake_tags_aligned_nested_perfect,
+                THRESHOLD,
             ),
             expected_matches_nested_perfect,
         ),
@@ -171,6 +175,7 @@ expected_matches_nested_false = {"All": 2, "PER": 1, "LOC": 1}
                 fake_string_nested,
                 fake_tags_aligned_nested_perfect,
                 fake_tags_aligned_nested_false,
+                THRESHOLD,
             ),
             expected_matches_nested_false,
         ),
@@ -182,4 +187,4 @@ def test_compute_matches(test_input, expected):
 def test_compute_matches_empty_entry():
     with pytest.raises(AssertionError):
-        evaluate.compute_matches(None, None, None, None)
+        evaluate.compute_matches(None, None, None, None, None)
@@ -3,6 +3,8 @@ import pytest
 from nerval import evaluate

+THRESHOLD = 0.30
+
 FAKE_ANNOT_BIO = "tests/test_annot.bio"
 FAKE_PREDICT_BIO = "tests/test_predict.bio"
 EMPTY_BIO = "tests/test_empty.bio"
@@ -62,8 +64,8 @@ expected_scores = {
 @pytest.mark.parametrize(
     "test_input, expected",
     [
-        ((FAKE_ANNOT_BIO, FAKE_PREDICT_BIO), expected_scores),
-        ((FAKE_BIO_NESTED, FAKE_BIO_NESTED), expected_scores_nested),
+        ((FAKE_ANNOT_BIO, FAKE_PREDICT_BIO, THRESHOLD), expected_scores),
+        ((FAKE_BIO_NESTED, FAKE_BIO_NESTED, THRESHOLD), expected_scores_nested),
     ],
 )
 def test_run(test_input, expected):
@@ -73,9 +75,9 @@ def test_run(test_input, expected):
 def test_run_empty_bio():
     with pytest.raises(Exception):
-        evaluate.run(EMPTY_BIO, EMPTY_BIO)
+        evaluate.run(EMPTY_BIO, EMPTY_BIO, THRESHOLD)

 def test_run_empty_entry():
     with pytest.raises(TypeError):
-        evaluate.run(None, None)
+        evaluate.run(None, None, THRESHOLD)