From 6cdfc750f6d3ebc9c1b4bde236e221daab3348bf Mon Sep 17 00:00:00 2001
From: Solène Tarride <starride@teklia.com>
Date: Fri, 9 Dec 2022 10:33:02 +0000
Subject: [PATCH] Normalize WER computation
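
Add a new `wer_no_punct` metric: the word error rate computed after
stripping punctuation from both the ground truth and the prediction.
`format_string_for_wer` and `edit_wer_from_string` gain a
`remove_punct` flag (default False) that removes punctuation before
the string is split into words; `MetricManager` aggregates the metric
as sum(edit_words_no_punct) / sum(nb_words_no_punct), and the line and
document training scripts report it during training and evaluation.

As part of the normalization, punctuation marks are no longer padded
into standalone tokens for the standard WER either: words are now
split on whitespace only. Also dedent a `return proba` in
`dan/manager/ocr.py` so the method always returns the computed
probability.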

---
 dan/manager/metrics.py             | 38 ++++++++++++++++++++++++------
 dan/manager/ocr.py                 |  2 +-
 dan/ocr/document/train.py          |  4 +++-
 dan/ocr/line/generate_synthetic.py |  8 ++++++-
 dan/ocr/line/train.py              |  9 ++++++-
 5 files changed, 50 insertions(+), 11 deletions(-)
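
Illustration for reviewers (not part of the patch): a minimal,
self-contained approximation of the patched `format_string_for_wer`.
Layout-token handling is omitted, and the final collapse-and-split
step is assumed from the function's docstring rather than shown in the
hunks below; the sample strings are invented.

    import re

    # Punctuation character class copied from the patch.
    PUNCTUATION = r"([\[\]{}/\\()\"'&+*=<>?.;:,!\-—_€#%°])"

    def format_string_for_wer(s, remove_punct=False):
        # Simplified stand-in: no layout-token removal.
        if remove_punct:
            s = re.sub(PUNCTUATION, "", s)  # drop punctuation entirely
        # Assumed final step: collapse spaces/line breaks, split on spaces.
        s = re.sub(r"([ \n])+", " ", s).strip()
        return s.split(" ")

    print(format_string_for_wer("Hello, world!"))
    # ['Hello,', 'world!']
    print(format_string_for_wer("Hello, world!", remove_punct=True))
    # ['Hello', 'world']

The WER itself is then the word-level edit distance between the two
word lists divided by the number of ground-truth words, matching the
sum(edit_words*) / sum(nb_words*) aggregation in MetricManager.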

diff --git a/dan/manager/metrics.py b/dan/manager/metrics.py
index 36d3a29f..fdcc0ea5 100644
--- a/dan/manager/metrics.py
+++ b/dan/manager/metrics.py
@@ -31,6 +31,7 @@ class MetricManager:
         self.linked_metrics = {
             "cer": ["edit_chars", "nb_chars"],
             "wer": ["edit_words", "nb_words"],
+            "wer_no_punct": ["edit_words_no_punct", "nb_words_no_punct"],
             "loer": [
                 "edit_graph",
                 "nb_nodes_and_edges",
@@ -127,6 +128,14 @@ class MetricManager:
                 )
                 if output:
                     display_values["nb_words"] = np.sum(self.epoch_metrics["nb_words"])
+            elif metric_name == "wer_no_punct":
+                value = np.sum(self.epoch_metrics["edit_words_no_punct"]) / np.sum(
+                    self.epoch_metrics["nb_words_no_punct"]
+                )
+                if output:
+                    display_values["nb_words_no_punct"] = np.sum(
+                        self.epoch_metrics["nb_words_no_punct"]
+                    )
             elif metric_name in [
                 "loss",
                 "loss_ctc",
@@ -183,6 +192,20 @@ class MetricManager:
                     for (gt, pred) in zip(split_gt, split_pred)
                 ]
                 metrics["nb_words"] = [len(gt) for gt in split_gt]
+            elif metric_name == "wer_no_punct":
+                split_gt = [
+                    format_string_for_wer(gt, self.layout_tokens, remove_punct=True)
+                    for gt in values["str_y"]
+                ]
+                split_pred = [
+                    format_string_for_wer(pred, self.layout_tokens, remove_punct=True)
+                    for pred in values["str_x"]
+                ]
+                metrics["edit_words_no_punct"] = [
+                    edit_wer_from_formatted_split_text(gt, pred)
+                    for (gt, pred) in zip(split_gt, split_pred)
+                ]
+                metrics["nb_words_no_punct"] = [len(gt) for gt in split_gt]
             elif metric_name in [
                 "loss_ctc",
                 "loss_ce",
@@ -258,22 +281,23 @@ def nb_chars_cer_from_string(gt, layout_tokens=None):
     return len(format_string_for_cer(gt, layout_tokens))
 
 
-def edit_wer_from_string(gt, pred, layout_tokens=None):
+def edit_wer_from_string(gt, pred, layout_tokens=None, remove_punct=False):
     """
     Format and compute edit distance between two strings at word level
     """
-    split_gt = format_string_for_wer(gt, layout_tokens)
-    split_pred = format_string_for_wer(pred, layout_tokens)
+    split_gt = format_string_for_wer(gt, layout_tokens, remove_punct)
+    split_pred = format_string_for_wer(pred, layout_tokens, remove_punct)
     return edit_wer_from_formatted_split_text(split_gt, split_pred)
 
 
-def format_string_for_wer(str, layout_tokens):
+def format_string_for_wer(str, layout_tokens, remove_punct=False):
     """
     Format string for WER computation: remove layout tokens, treat punctuation as word, replace line break by space
     """
-    str = re.sub(
-        r"([\[\]{}/\\()\"'&+*=<>?.;:,!\-—_€#%°])", r" \1 ", str
-    )  # punctuation processed as word
+    if remove_punct:
+        str = re.sub(
+            r"([\[\]{}/\\()\"'&+*=<>?.;:,!\-—_€#%°])", "", str
+        )  # remove punctuation
     if layout_tokens is not None:
         str = keep_all_but_tokens(
             str, layout_tokens
diff --git a/dan/manager/ocr.py b/dan/manager/ocr.py
index 775fce63..93b966de 100644
--- a/dan/manager/ocr.py
+++ b/dan/manager/ocr.py
@@ -369,7 +369,7 @@ class OCRDataset(GenericDataset):
                 min(nb_samples, self.synthetic_data["num_steps_proba"]),
                 self.synthetic_data["num_steps_proba"],
             )
-            return proba
+        return proba
 
     def generate_synthetic_page_sample(self):
         max_nb_lines_per_page = self.get_syn_max_lines()
diff --git a/dan/ocr/document/train.py b/dan/ocr/document/train.py
index b173ba99..86af4d63 100644
--- a/dan/ocr/document/train.py
+++ b/dan/ocr/document/train.py
@@ -32,7 +32,7 @@ def train_and_test(rank, params):
     model.params["training_params"]["load_epoch"] = "best"
     model.load_model()
 
-    metrics = ["cer", "wer", "time"]
+    metrics = ["cer", "wer", "wer_no_punct", "time"]
     for dataset_name in params["dataset_params"]["datasets"].keys():
         for set_name in ["test", "val", "train"]:
             model.predict(
@@ -208,12 +208,14 @@ def run():
                 "loss_ce",
                 "cer",
                 "wer",
+                "wer_no_punct",
                 "syn_max_lines",
                 "syn_prob_lines",
             ],  # Metrics name for training
             "eval_metrics": [
                 "cer",
                 "wer",
+                "wer_no_punct",
             ],  # Metrics name for evaluation on validation set during training
             "force_cpu": False,  # True for debug purposes
             "max_char_prediction": 1000,  # max number of token prediction
diff --git a/dan/ocr/line/generate_synthetic.py b/dan/ocr/line/generate_synthetic.py
index 4efcd112..879d809f 100644
--- a/dan/ocr/line/generate_synthetic.py
+++ b/dan/ocr/line/generate_synthetic.py
@@ -136,11 +136,17 @@ def run():
             "focus_metric": "cer",  # Metrics to focus on to determine best epoch
             "expected_metric_value": "low",  # ["high", "low"] What is best for the focus metric value
             "set_name_focus_metric": "{}-val".format(dataset_name),
-            "train_metrics": ["loss_ctc", "cer", "wer"],  # Metrics name for training
+            "train_metrics": [
+                "loss_ctc",
+                "cer",
+                "wer",
+                "wer_no_punct",
+            ],  # Metrics name for training
             "eval_metrics": [
                 "loss_ctc",
                 "cer",
                 "wer",
+                "wer_no_punct",
             ],  # Metrics name for evaluation on validation set during training
             "force_cpu": False,  # True for debug purposes to run on cpu only
         },
diff --git a/dan/ocr/line/train.py b/dan/ocr/line/train.py
index e76a54f3..86d6009e 100644
--- a/dan/ocr/line/train.py
+++ b/dan/ocr/line/train.py
@@ -37,6 +37,7 @@ def train_and_test(rank, params):
     metrics = [
         "cer",
         "wer",
+        "wer_no_punct",
         "time",
     ]
     for dataset_name in params["dataset_params"]["datasets"].keys():
@@ -178,11 +179,17 @@ def run():
             "set_name_focus_metric": "{}-val".format(
                 dataset_name
             ),  # Which dataset to focus on to select best weights
-            "train_metrics": ["loss_ctc", "cer", "wer"],  # Metrics name for training
+            "train_metrics": [
+                "loss_ctc",
+                "cer",
+                "wer",
+                "wer_no_punct",
+            ],  # Metrics name for training
             "eval_metrics": [
                 "loss_ctc",
                 "cer",
                 "wer",
+                "wer_no_punct",
             ],  # Metrics name for evaluation on validation set during training
             "force_cpu": False,  # True for debug purposes to run on cpu only
         },
-- 
GitLab