From ede796fc0a6e119541b6584bc6a6ee5a277b8b50 Mon Sep 17 00:00:00 2001
From: Solene Tarride <starride@teklia.com>
Date: Thu, 6 Jun 2024 16:13:47 +0000
Subject: [PATCH] Fix score computation when threshold=1.0
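
With threshold=1.0, an empty prediction compared against a non-empty
annotation has a Character Error Rate of exactly 1.0, so it was wrongly
counted as a match. Matching is now delegated to a `match` helper that
only matches entities when both the annotation and the prediction are
non-empty and their CER is below or equal to the threshold. The
non-empty check on the annotation also rules out a division by zero
when the reference entity is empty.

Illustrative example of the new helper (entity strings made up for
demonstration):

    >>> match("Georges", "Georges", 1.0)
    True
    >>> match("Georges", "", 1.0)  # empty predictions no longer match
    False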

---
 nerval/evaluate.py | 30 +++++++++++++++++++++---------
 1 file changed, 21 insertions(+), 9 deletions(-)

diff --git a/nerval/evaluate.py b/nerval/evaluate.py
index d741a83..e89d621 100644
--- a/nerval/evaluate.py
+++ b/nerval/evaluate.py
@@ -23,6 +23,26 @@ PRED_COLUMN = "Prediction"
 CSV_HEADER = [ANNO_COLUMN, PRED_COLUMN]
 
 
+def match(annotation: str, prediction: str, threshold: float) -> bool:
+    """Test if two entities match based on their character edit distance.
+    Entities should be matched if both entity exist (e.g. not empty strings) and their Character Error Rate is below the threshold.
+    Otherwise they should not be matched.
+
+    Args:
+        annotation (str): ground-truth entity.
+        prediction (str): predicted entity.
+        threshold (float): maximum Character Error Rate allowed for a match.
+
+    Returns:
+        bool: Whether to match these two entities.
+    """
+    return (
+        annotation != ""
+        and prediction != ""
+        and editdistance.eval(annotation, prediction) / len(annotation) <= threshold
+    )
+
+
 def compute_matches(
     annotation: str,
     prediction: str,
@@ -158,24 +178,17 @@ def compute_matches(
                 # Normalize collected strings
                 entity_ref = "".join(current_ref)
                 entity_ref = entity_ref.replace("-", "")
-                len_entity = len(entity_ref)
                 entity_compar = "".join(current_compar)
                 entity_compar = entity_compar.replace("-", "")
 
                 # One entity is counted as recognized (score of 1) if the Levenshtein distance between the expected and predicted entities
                 # represents at most 30% (THRESHOLD) of the length of the expected entity.
                 # Precision and recall are then computed for each category by comparing the numbers of recognized and expected entities
-                score = (
-                    1
-                    if editdistance.eval(entity_ref, entity_compar) / len_entity
-                    <= threshold
-                    else 0
-                )
+                score = int(match(entity_ref, entity_compar, threshold))
                 entity_count[last_tag] = entity_count.get(last_tag, 0) + score
                 entity_count[ALL_ENTITIES] += score
                 current_ref = []
                 current_compar = []
-
     return entity_count
 
 
@@ -263,7 +276,6 @@ def compute_scores(
             if (prec + rec == 0)
             else 2 * (prec * rec) / (prec + rec)
         )
-
         scores[tag]["predicted"] = nb_predict
         scores[tag]["matched"] = nb_match
         scores[tag]["P"] = prec
-- 
GitLab