From a035263782e8a94907981df86efbfb2939c32c65 Mon Sep 17 00:00:00 2001
From: Eva Bardou <bardou@teklia.com>
Date: Thu, 23 Nov 2023 11:06:21 +0000
Subject: [PATCH] Evaluate NER token distance

---
 dan/ocr/evaluate.py        |  3 +++
 dan/ocr/manager/metrics.py | 17 +++++++++++++++++
 dan/ocr/train.py           |  6 ++++++
 3 files changed, 26 insertions(+)

diff --git a/dan/ocr/evaluate.py b/dan/ocr/evaluate.py
index f8b09a23..6cc95a03 100644
--- a/dan/ocr/evaluate.py
+++ b/dan/ocr/evaluate.py
@@ -51,6 +51,9 @@ def eval(rank, config, mlflow_logging):
     model.load_model()
 
     metrics = ["cer", "wer", "wer_no_punct", "time"]
+    if config["dataset"]["tokens"] is not None:
+        metrics.append("ner")
+
     for dataset_name in config["dataset"]["datasets"]:
         for set_name in ["test", "val", "train"]:
             logger.info(f"Evaluating on set `{set_name}`")
diff --git a/dan/ocr/manager/metrics.py b/dan/ocr/manager/metrics.py
index f48da65c..f6b6ec74 100644
--- a/dan/ocr/manager/metrics.py
+++ b/dan/ocr/manager/metrics.py
@@ -32,6 +32,8 @@ class MetricManager:
                 + list(map(attrgetter("end"), tokens.values()))
             )
             self.remove_tokens: re.Pattern = re.compile(r"([" + layout_tokens + "])")
+            self.keep_tokens: re.Pattern = re.compile(r"([^" + layout_tokens + "])")
+
         self.metric_names: List[str] = metric_names
         self.epoch_metrics = defaultdict(list)
 
@@ -69,6 +71,12 @@ class MetricManager:
         text = REGEX_CONSECUTIVE_LINEBREAKS.sub("\n", text)
         return REGEX_CONSECUTIVE_SPACES.sub(" ", text).strip()
 
+    def format_string_for_ner(self, text: str):
+        """
+        Format string for NER computation: only keep layout tokens
+        """
+        return self.keep_tokens.sub("", text)
+
     def update_metrics(self, batch_metrics):
         """
         Add batch metrics to the metrics
@@ -100,6 +108,8 @@ class MetricManager:
                 case "wer" | "wer_no_punct":
                     suffix = metric_name[3:]
                     num_name, denom_name = "edit_words" + suffix, "nb_words" + suffix
+                case "ner":
+                    num_name, denom_name = "edit_tokens", "nb_tokens"
                 case "loss" | "loss_ce":
                     display_values[metric_name] = round(
                         float(
@@ -152,6 +162,13 @@ class MetricManager:
                         map(editdistance.eval, split_gt, split_pred)
                     )
                     metrics["nb_words" + suffix] = list(map(len, split_gt))
+                case "ner":
+                    split_gt = list(map(self.format_string_for_ner, gt))
+                    split_pred = list(map(self.format_string_for_ner, prediction))
+                    metrics["edit_tokens"] = list(
+                        map(editdistance.eval, split_gt, split_pred)
+                    )
+                    metrics["nb_tokens"] = list(map(len, split_gt))
                 case "loss" | "loss_ce":
                     metrics[metric_name] = [
                         values[metric_name],
diff --git a/dan/ocr/train.py b/dan/ocr/train.py
index 8385111c..ffc5aa15 100644
--- a/dan/ocr/train.py
+++ b/dan/ocr/train.py
@@ -34,6 +34,12 @@ def train(rank, params, mlflow_logging=False):
     model = Manager(params)
     model.load_model()
 
+    if params["dataset"]["tokens"] is not None:
+        if "ner" not in params["training"]["metrics"]["train"]:
+            params["training"]["metrics"]["train"].append("ner")
+        if "ner" not in params["training"]["metrics"]["eval"]:
+            params["training"]["metrics"]["eval"].append("ner")
+
     if mlflow_logging:
         logger.info("MLflow logging enabled")
 
-- 
GitLab