diff --git a/nerval/evaluate.py b/nerval/evaluate.py
index 7c39ab760137b7b0a42a6f8570592f77a112d421..f32069d55b388ce1a5d67a37fb1c70b3033e5b68 100644
--- a/nerval/evaluate.py
+++ b/nerval/evaluate.py
@@ -9,7 +9,6 @@
 import editdistance
 import edlib
 import termtables as tt
 
-THRESHOLD = 0.30
 NOT_ENTITY_TAG = "O"
 
@@ -189,7 +188,11 @@ def look_for_further_entity_part(index, tag, characters, labels):
 
 
 def compute_matches(
-    annotation: str, prediction: str, labels_annot: list, labels_predict: list
+    annotation: str,
+    prediction: str,
+    labels_annot: list,
+    labels_predict: list,
+    threshold: float,
 ) -> dict:
     """Compute prediction score from annotation string to prediction string.
 
@@ -324,7 +327,7 @@ def compute_matches(
                score = (
                    1
                    if editdistance.eval(entity_ref, entity_compar) / len_entity
-                   < THRESHOLD
+                   <= threshold
                    else 0
                )
                entity_count[last_tag] = entity_count.get(last_tag, 0) + score
@@ -454,7 +457,7 @@ def print_results(scores: dict):
     tt.print(results, header, style=tt.styles.markdown)
 
 
-def run(annotation: str, prediction: str) -> dict:
+def run(annotation: str, prediction: str, threshold: float) -> dict:
     """Compute recall and precision for each entity type found in annotation and/or prediction.
 
     Each measure is given at document level, global score is a micro-average across entity types.
@@ -486,7 +489,11 @@ def run(annotation: str, prediction: str) -> dict:
 
     # Get nb match
     matches = compute_matches(
-        annot_aligned, predict_aligned, labels_annot_aligned, labels_predict_aligned
+        annot_aligned,
+        predict_aligned,
+        labels_annot_aligned,
+        labels_predict_aligned,
+        threshold,
     )
 
     # Compute scores
@@ -498,6 +505,17 @@ def run(annotation: str, prediction: str) -> dict:
     return scores
 
 
+def threshold_float_type(arg):
+    """Type function for argparse."""
+    try:
+        f = float(arg)
+    except ValueError:
+        raise argparse.ArgumentTypeError("Must be a floating point number.")
+    if f < 0 or f > 1:
+        raise argparse.ArgumentTypeError("Must be between 0 and 1.")
+    return f
+
+
 def main():
     """Get arguments and run."""
 
@@ -510,9 +528,17 @@ def main():
     parser.add_argument(
         "-p", "--predict", help="Prediction in BIO format.", required=True
     )
+    parser.add_argument(
+        "-t",
+        "--threshold",
+        help="Set a distance threshold for the match between gold and predicted entity.",
+        required=False,
+        default=0.3,
+        type=threshold_float_type,
+    )
     args = parser.parse_args()
 
-    run(args.annot, args.predict)
+    run(args.annot, args.predict, args.threshold)
 
 
 if __name__ == "__main__":
diff --git a/tests/test_compute_matches.py b/tests/test_compute_matches.py
index 8d5d28a504418cf0b833a230f4a15aed670fe466..924310361ba3baf4504cbb1b6aa50d6b91b0fee9 100644
--- a/tests/test_compute_matches.py
+++ b/tests/test_compute_matches.py
@@ -3,6 +3,8 @@ import pytest
 
 from nerval import evaluate
 
+THRESHOLD = 0.30
+
 fake_annot_aligned = "Gérard de -N-erval was bo-rn in Paris in 1808 -."
 fake_predict_aligned = "G*rard de *N*erval ----bo*rn in Paris in 1833 *."
@@ -153,6 +155,7 @@ expected_matches_nested_false = {"All": 2, "PER": 1, "LOC": 1}
                 fake_predict_aligned,
                 fake_annot_tags_aligned,
                 fake_predict_tags_aligned,
+                THRESHOLD,
             ),
             expected_matches,
         ),
@@ -162,6 +165,7 @@ expected_matches_nested_false = {"All": 2, "PER": 1, "LOC": 1}
                 fake_string_nested,
                 fake_tags_aligned_nested_perfect,
                 fake_tags_aligned_nested_perfect,
+                THRESHOLD,
             ),
             expected_matches_nested_perfect,
         ),
@@ -171,6 +175,7 @@ expected_matches_nested_false = {"All": 2, "PER": 1, "LOC": 1}
                 fake_string_nested,
                 fake_tags_aligned_nested_perfect,
                 fake_tags_aligned_nested_false,
+                THRESHOLD,
             ),
             expected_matches_nested_false,
         ),
@@ -182,4 +187,4 @@ def test_compute_matches(test_input, expected):
 
 def test_compute_matches_empty_entry():
     with pytest.raises(AssertionError):
-        evaluate.compute_matches(None, None, None, None)
+        evaluate.compute_matches(None, None, None, None, None)
diff --git a/tests/test_run.py b/tests/test_run.py
index a9ca42424637bdc967e6f86f9244d7fc12b2d7b9..d0ea6d59880e5da982a949b80818215af5358160 100644
--- a/tests/test_run.py
+++ b/tests/test_run.py
@@ -3,6 +3,8 @@ import pytest
 
 from nerval import evaluate
 
+THRESHOLD = 0.30
+
 FAKE_ANNOT_BIO = "tests/test_annot.bio"
 FAKE_PREDICT_BIO = "tests/test_predict.bio"
 EMPTY_BIO = "tests/test_empty.bio"
@@ -62,8 +64,8 @@ expected_scores = {
 @pytest.mark.parametrize(
     "test_input, expected",
     [
-        ((FAKE_ANNOT_BIO, FAKE_PREDICT_BIO), expected_scores),
-        ((FAKE_BIO_NESTED, FAKE_BIO_NESTED), expected_scores_nested),
+        ((FAKE_ANNOT_BIO, FAKE_PREDICT_BIO, THRESHOLD), expected_scores),
+        ((FAKE_BIO_NESTED, FAKE_BIO_NESTED, THRESHOLD), expected_scores_nested),
     ],
 )
 def test_run(test_input, expected):
@@ -73,9 +75,9 @@ def test_run(test_input, expected):
 
 def test_run_empty_bio():
     with pytest.raises(Exception):
-        evaluate.run(EMPTY_BIO, EMPTY_BIO)
+        evaluate.run(EMPTY_BIO, EMPTY_BIO, THRESHOLD)
 
 
 def test_run_empty_entry():
     with pytest.raises(TypeError):
-        evaluate.run(None, None)
+        evaluate.run(None, None, THRESHOLD)
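For reference, a minimal usage sketch of the API after this change (the .bio paths are placeholders, and the shell invocation assumes the script is run directly through the argparse entry point defined in main()):

    from nerval import evaluate

    # Entities whose normalized edit distance to the reference entity is at most
    # the threshold (default 0.3) are counted as matches.
    scores = evaluate.run("gold.bio", "predicted.bio", 0.3)

    # Roughly equivalent invocation from the shell:
    #   python nerval/evaluate.py -a gold.bio -p predicted.bio -t 0.3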