From ede796fc0a6e119541b6584bc6a6ee5a277b8b50 Mon Sep 17 00:00:00 2001
From: Solene Tarride <starride@teklia.com>
Date: Thu, 6 Jun 2024 16:13:47 +0000
Subject: [PATCH] Fix score computation when threshold=1.0

---
 nerval/evaluate.py | 30 +++++++++++++++++++++---------
 1 file changed, 21 insertions(+), 9 deletions(-)

diff --git a/nerval/evaluate.py b/nerval/evaluate.py
index d741a83..e89d621 100644
--- a/nerval/evaluate.py
+++ b/nerval/evaluate.py
@@ -23,6 +23,26 @@ PRED_COLUMN = "Prediction"
 CSV_HEADER = [ANNO_COLUMN, PRED_COLUMN]
 
 
+def match(annotation: str, prediction: str, threshold: float) -> bool:
+    """Test if two entities match based on their character edit distance.
+    Two entities match if both exist (i.e. are not empty strings) and their Character Error Rate does not exceed the threshold.
+    Otherwise, they do not match.
+
+    Args:
+        annotation (str): ground-truth entity.
+        prediction (str): predicted entity.
+        threshold (float): matching threshold.
+
+    Returns:
+        bool: Whether the two entities match.
+    """
+    return (
+        annotation != ""
+        and prediction != ""
+        and editdistance.eval(annotation, prediction) / len(annotation) <= threshold
+    )
+
+
 def compute_matches(
     annotation: str,
     prediction: str,
@@ -158,24 +178,17 @@ def compute_matches(
         # Normalize collected strings
         entity_ref = "".join(current_ref)
         entity_ref = entity_ref.replace("-", "")
-        len_entity = len(entity_ref)
         entity_compar = "".join(current_compar)
         entity_compar = entity_compar.replace("-", "")
 
         # One entity is counted as recognized (score of 1) if the Levenhstein distance between the expected and predicted entities
         # represents less than 30% (THRESHOLD) of the length of the expected entity.
         # Precision and recall will be computed for each category in comparing the numbers of recognized entities and expected entities
-        score = (
-            1
-            if editdistance.eval(entity_ref, entity_compar) / len_entity
-            <= threshold
-            else 0
-        )
+        score = int(match(entity_ref, entity_compar, threshold))
         entity_count[last_tag] = entity_count.get(last_tag, 0) + score
         entity_count[ALL_ENTITIES] += score
         current_ref = []
         current_compar = []
-
     return entity_count
 
 
@@ -263,7 +276,6 @@ def compute_scores(
             if (prec + rec == 0)
             else 2 * (prec * rec) / (prec + rec)
         )
-
        scores[tag]["predicted"] = nb_predict
        scores[tag]["matched"] = nb_match
        scores[tag]["P"] = prec
--
GitLab
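
The patch title refers to the threshold=1.0 case: with the old expression, an entity was counted as recognized whenever editdistance.eval(entity_ref, entity_compar) / len(entity_ref) <= threshold, so an empty prediction compared against a non-empty annotation gave a normalised distance of exactly 1.0 and was counted as a match when threshold=1.0. The new match() helper additionally requires both strings to be non-empty. The following minimal sketch reproduces the difference; it is not part of the patch, it only assumes the editdistance package that nerval already uses, and the sample strings ("Paris", "Pariis") are illustrative, not taken from the project.

import editdistance


def match(annotation: str, prediction: str, threshold: float) -> bool:
    # Same logic as the helper added by this patch: both entities must be
    # non-empty and the normalised edit distance (CER) must not exceed
    # the threshold.
    return (
        annotation != ""
        and prediction != ""
        and editdistance.eval(annotation, prediction) / len(annotation) <= threshold
    )


annotation = "Paris"    # illustrative ground-truth entity
missing = ""            # prediction is missing entirely

# Old scoring rule: 5 edits / 5 characters = 1.0 <= 1.0, so a missing
# prediction was still scored as recognized when threshold=1.0.
old_score = int(editdistance.eval(annotation, missing) / len(annotation) <= 1.0)
print(old_score)  # 1

# New rule: empty entities never match, whatever the threshold.
print(int(match(annotation, missing, threshold=1.0)))  # 0

# Non-empty, close predictions still match as before (CER = 1/5 = 0.2).
print(match(annotation, "Pariis", threshold=0.30))  # True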