From 7a4af1eb3033511a43533a4b5239e7be6501fb0a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sol=C3=A8ne=20Tarride?= <starride@teklia.com>
Date: Mon, 5 Dec 2022 09:51:41 +0000
Subject: [PATCH] Monitor probability of synthetic documents

---
 dan/manager/metrics.py    |  9 ++++++-
 dan/manager/ocr.py        | 50 +++++++++++++++++++++++----------------
 dan/manager/training.py   |  3 +++
 dan/ocr/document/train.py |  1 +
 4 files changed, 42 insertions(+), 21 deletions(-)

diff --git a/dan/manager/metrics.py b/dan/manager/metrics.py
index c9e3a642..36d3a29f 100644
--- a/dan/manager/metrics.py
+++ b/dan/manager/metrics.py
@@ -127,7 +127,13 @@ class MetricManager:
                 )
                 if output:
                     display_values["nb_words"] = np.sum(self.epoch_metrics["nb_words"])
-            elif metric_name in ["loss", "loss_ctc", "loss_ce", "syn_max_lines"]:
+            elif metric_name in [
+                "loss",
+                "loss_ctc",
+                "loss_ce",
+                "syn_max_lines",
+                "syn_prob_lines",
+            ]:
                 value = np.average(
                     self.epoch_metrics[metric_name],
                     weights=np.array(self.epoch_metrics["nb_samples"]),
@@ -182,6 +188,7 @@ class MetricManager:
                 "loss_ce",
                 "loss",
                 "syn_max_lines",
+                "syn_prob_lines",
             ]:
                 metrics[metric_name] = [
                     values[metric_name],
diff --git a/dan/manager/ocr.py b/dan/manager/ocr.py
index 43d46894..d097eae3 100644
--- a/dan/manager/ocr.py
+++ b/dan/manager/ocr.py
@@ -284,9 +284,37 @@ class OCRDataset(GenericDataset):
 
     def generate_synthetic_data(self, sample):
         config = self.params["config"]["synthetic_data"]
+        proba = self.get_syn_proba_lines()
 
-        if not (config["init_proba"] == config["end_proba"] == 1):
+        if rand() > proba:
+            return sample
+
+        if "mode" in config and config["mode"] == "line_hw_to_printed":
+            sample["img"] = self.generate_typed_text_line_image(sample["label"])
+            return sample
+
+        return self.generate_synthetic_page_sample()
+
+    def get_syn_max_lines(self):
+        config = self.params["config"]["synthetic_data"]
+        if config["curriculum"]:
             nb_samples = self.training_info["step"] * self.params["batch_size"]
+            max_nb_lines = min(
+                config["max_nb_lines"],
+                (nb_samples - config["curr_start"]) // config["curr_step"] + 1,
+            )
+            return max(config["min_nb_lines"], max_nb_lines)
+        return config["max_nb_lines"]
+
+    def get_syn_proba_lines(self):
+        config = self.params["config"]["synthetic_data"]
+
+        if config["init_proba"] == config["end_proba"]:
+            return config["init_proba"]
+
+        else:
+            nb_samples = self.training_info["step"] * self.params["batch_size"]
+
             if config["start_scheduler_at_max_line"]:
                 max_step = config["num_steps_proba"]
                 current_step = max(
@@ -315,25 +343,7 @@ class OCRDataset(GenericDataset):
                     min(nb_samples, config["num_steps_proba"]),
                     config["num_steps_proba"],
                 )
-            if rand() > proba:
-                return sample
-
-        if "mode" in config and config["mode"] == "line_hw_to_printed":
-            sample["img"] = self.generate_typed_text_line_image(sample["label"])
-            return sample
-
-        return self.generate_synthetic_page_sample()
-
-    def get_syn_max_lines(self):
-        config = self.params["config"]["synthetic_data"]
-        if config["curriculum"]:
-            nb_samples = self.training_info["step"] * self.params["batch_size"]
-            max_nb_lines = min(
-                config["max_nb_lines"],
-                (nb_samples - config["curr_start"]) // config["curr_step"] + 1,
-            )
-            return max(config["min_nb_lines"], max_nb_lines)
-        return config["max_nb_lines"]
+            return proba
 
     def generate_synthetic_page_sample(self):
         config = self.params["config"]["synthetic_data"]
diff --git a/dan/manager/training.py b/dan/manager/training.py
index 682e3091..41ba05cb 100644
--- a/dan/manager/training.py
+++ b/dan/manager/training.py
@@ -1160,6 +1160,9 @@ class Manager(OCRManager):
             "syn_max_lines": self.dataset.train_dataset.get_syn_max_lines()
             if self.params["dataset_params"]["config"]["synthetic_data"]
             else 0,
+            "syn_prob_lines": self.dataset.train_dataset.get_syn_proba_lines()
+            if self.params["dataset_params"]["config"]["synthetic_data"]
+            else 0,
         }
 
         return values
diff --git a/dan/ocr/document/train.py b/dan/ocr/document/train.py
index a4d3f5e8..b173ba99 100644
--- a/dan/ocr/document/train.py
+++ b/dan/ocr/document/train.py
@@ -209,6 +209,7 @@ def run():
                 "cer",
                 "wer",
                 "syn_max_lines",
+                "syn_prob_lines",
             ],  # Metrics name for training
             "eval_metrics": [
                 "cer",
-- 
GitLab