From a26125a566274af201dc6d210a3e2283141b74e7 Mon Sep 17 00:00:00 2001
From: Yoann Schneider <yschneider@teklia.com>
Date: Fri, 12 May 2023 07:45:36 +0000
Subject: [PATCH] Save metrics results to YAML instead of plain text

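Metric values coming out of numpy do not serialize cleanly with PyYAML:
yaml.safe_dump rejects numpy scalars outright, and the default dumper
falls back to verbose !!python/object tags. Every value is therefore
cast to a native int or float before being dumped. A minimal
illustration (not part of this patch):

    import numpy as np
    import yaml

    yaml.safe_dump({"cer": np.float64(0.5)})         # raises yaml.representer.RepresenterError
    yaml.safe_dump({"cer": float(np.float64(0.5))})  # 'cer: 0.5\n'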
---
 dan/manager/metrics.py  | 65 ++++++++++++++++++++++---------------
 dan/manager/training.py | 10 +++---
 tests/test_training.py  | 68 +++++++++++++++++++------------------
 3 files changed, 78 insertions(+), 65 deletions(-)

diff --git a/dan/manager/metrics.py b/dan/manager/metrics.py
index fdcc0ea5..8c4682fe 100644
--- a/dan/manager/metrics.py
+++ b/dan/manager/metrics.py
@@ -87,54 +87,62 @@ class MetricManager:
             value = None
             if output:
                 if metric_name in ["nb_samples", "weights"]:
-                    value = np.sum(self.epoch_metrics[metric_name])
+                    value = int(np.sum(self.epoch_metrics[metric_name]))
                 elif metric_name in [
                     "time",
                 ]:
-                    total_time = np.sum(self.epoch_metrics[metric_name])
-                    sample_time = total_time / np.sum(self.epoch_metrics["nb_samples"])
-                    display_values["sample_time"] = round(sample_time, 4)
-                    value = total_time
+                    value = int(np.sum(self.epoch_metrics[metric_name]))
+                    sample_time = value / np.sum(self.epoch_metrics["nb_samples"])
+                    display_values["sample_time"] = float(round(sample_time, 4))
                 elif metric_name == "loer":
-                    display_values["pper"] = round(
-                        np.sum(self.epoch_metrics["nb_pp_op_layout"])
-                        / np.sum(self.epoch_metrics["nb_gt_layout_token"]),
-                        4,
+                    display_values["pper"] = float(
+                        round(
+                            np.sum(self.epoch_metrics["nb_pp_op_layout"])
+                            / np.sum(self.epoch_metrics["nb_gt_layout_token"]),
+                            4,
+                        )
                     )
                 elif metric_name == "map_cer_per_class":
                     value = compute_global_mAP_per_class(self.epoch_metrics["map_cer"])
                     for key in value.keys():
-                        display_values["map_cer_" + key] = round(value[key], 4)
+                        display_values["map_cer_" + key] = float(round(value[key], 4))
                     continue
                 elif metric_name == "layout_precision_per_class_per_threshold":
                     value = compute_global_precision_per_class_per_threshold(
                         self.epoch_metrics["map_cer"]
                     )
                     for key_class in value.keys():
                         for threshold in value[key_class].keys():
                             display_values[
                                 "map_cer_{}_{}".format(key_class, threshold)
-                            ] = round(value[key_class][threshold], 4)
+                            ] = float(round(value[key_class][threshold], 4))
                     continue
             if metric_name == "cer":
-                value = np.sum(self.epoch_metrics["edit_chars"]) / np.sum(
-                    self.epoch_metrics["nb_chars"]
+                value = float(
+                    np.sum(self.epoch_metrics["edit_chars"])
+                    / np.sum(self.epoch_metrics["nb_chars"])
                 )
                 if output:
-                    display_values["nb_chars"] = np.sum(self.epoch_metrics["nb_chars"])
+                    display_values["nb_chars"] = int(
+                        np.sum(self.epoch_metrics["nb_chars"])
+                    )
             elif metric_name == "wer":
-                value = np.sum(self.epoch_metrics["edit_words"]) / np.sum(
-                    self.epoch_metrics["nb_words"]
+                value = float(
+                    np.sum(self.epoch_metrics["edit_words"])
+                    / np.sum(self.epoch_metrics["nb_words"])
                 )
                 if output:
-                    display_values["nb_words"] = np.sum(self.epoch_metrics["nb_words"])
+                    display_values["nb_words"] = int(
+                        np.sum(self.epoch_metrics["nb_words"])
+                    )
             elif metric_name == "wer_no_punct":
-                value = np.sum(self.epoch_metrics["edit_words_no_punct"]) / np.sum(
-                    self.epoch_metrics["nb_words_no_punct"]
+                value = float(
+                    np.sum(self.epoch_metrics["edit_words_no_punct"])
+                    / np.sum(self.epoch_metrics["nb_words_no_punct"])
                 )
                 if output:
-                    display_values["nb_words_no_punct"] = np.sum(
-                        self.epoch_metrics["nb_words_no_punct"]
+                    display_values["nb_words_no_punct"] = int(
+                        np.sum(self.epoch_metrics["nb_words_no_punct"])
                     )
             elif metric_name in [
                 "loss",
@@ -143,15 +151,18 @@ class MetricManager:
                 "syn_max_lines",
                 "syn_prob_lines",
             ]:
-                value = np.average(
-                    self.epoch_metrics[metric_name],
-                    weights=np.array(self.epoch_metrics["nb_samples"]),
+                value = float(
+                    np.average(
+                        self.epoch_metrics[metric_name],
+                        weights=np.array(self.epoch_metrics["nb_samples"]),
+                    )
                 )
             elif metric_name == "map_cer":
-                value = compute_global_mAP(self.epoch_metrics[metric_name])
+                value = float(compute_global_mAP(self.epoch_metrics[metric_name]))
             elif metric_name == "loer":
-                value = np.sum(self.epoch_metrics["edit_graph"]) / np.sum(
-                    self.epoch_metrics["nb_nodes_and_edges"]
+                value = float(
+                    np.sum(self.epoch_metrics["edit_graph"])
+                    / np.sum(self.epoch_metrics["nb_nodes_and_edges"])
                 )
             elif value is None:
                 continue
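
Note that compute_global_mAP_per_class and
compute_global_precision_per_class_per_threshold return dicts keyed by
class, so the cast to a native float has to happen on the individual
entries (as done above), not on the returned mapping itself. An
alternative to scattering casts through every branch would be a single
recursive converter applied just before dumping; a sketch, assuming
metric values are numpy scalars, dicts, or lists (the helper name is
hypothetical, it is not part of this patch):

    import numpy as np

    def to_native(value):
        # np.generic covers all numpy scalars; .item() yields the
        # matching built-in type (np.float64 -> float, np.int64 -> int).
        if isinstance(value, np.generic):
            return value.item()
        if isinstance(value, dict):
            return {key: to_native(val) for key, val in value.items()}
        if isinstance(value, (list, tuple)):
            return [to_native(val) for val in value]
        return value
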
diff --git a/dan/manager/training.py b/dan/manager/training.py
index dc9267b9..c52cebd6 100644
--- a/dan/manager/training.py
+++ b/dan/manager/training.py
@@ -12,6 +12,7 @@ import numpy as np
 import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
+import yaml
 from PIL import Image
 from torch.cuda.amp import GradScaler, autocast
 from torch.nn import CrossEntropyLoss
@@ -865,22 +866,21 @@ class GenericTrainingManager:
             metrics = self.metric_manager[custom_name].get_display_values(output=True)
             path = os.path.join(
                 self.paths["results"],
-                "predict_{}_{}.txt".format(custom_name, self.latest_epoch),
+                "predict_{}_{}.yaml".format(custom_name, self.latest_epoch),
             )
             with open(path, "w") as f:
-                for metric_name in metrics.keys():
-                    f.write("{}: {}\n".format(metric_name, metrics[metric_name]))
+                yaml.dump(metrics, stream=f)
 
             # Log mlflow artifacts
             mlflow.log_artifact(path, "predictions")
 
     def output_pred(self, name):
         path = os.path.join(
-            self.paths["results"], "pred_{}_{}.txt".format(name, self.latest_epoch)
+            self.paths["results"], "pred_{}_{}.yaml".format(name, self.latest_epoch)
         )
         pred = "\n".join(self.metric_manager[name].get("pred"))
         with open(path, "w") as f:
-            f.write(pred)
+            yaml.dump(pred, stream=f)
 
     def launch_ddp(self):
         """
diff --git a/tests/test_training.py b/tests/test_training.py
index a674090a..b2c4ebbb 100644
--- a/tests/test_training.py
+++ b/tests/test_training.py
@@ -2,6 +2,7 @@
 
 import pytest
 import torch
+import yaml
 
 from dan.ocr.document.train import train_and_test
 from tests.conftest import FIXTURES
@@ -13,33 +14,33 @@ from tests.conftest import FIXTURES
         (
             "best_0.pt",
             "last_3.pt",
-            [
-                "nb_chars: 43",
-                "cer: 1.2791",
-                "nb_words: 9",
-                "wer: 1.0",
-                "nb_words_no_punct: 9",
-                "wer_no_punct: 1.0",
-                "nb_samples: 2",
-            ],
-            [
-                "nb_chars: 41",
-                "cer: 1.2683",
-                "nb_words: 9",
-                "wer: 1.0",
-                "nb_words_no_punct: 9",
-                "wer_no_punct: 1.0",
-                "nb_samples: 2",
-            ],
-            [
-                "nb_chars: 49",
-                "cer: 1.1429",
-                "nb_words: 9",
-                "wer: 1.0",
-                "nb_words_no_punct: 9",
-                "wer_no_punct: 1.0",
-                "nb_samples: 2",
-            ],
+            {
+                "nb_chars": 43,
+                "cer": 1.2791,
+                "nb_words": 9,
+                "wer": 1.0,
+                "nb_words_no_punct": 9,
+                "wer_no_punct": 1.0,
+                "nb_samples": 2,
+            },
+            {
+                "nb_chars": 41,
+                "cer": 1.2683,
+                "nb_words": 9,
+                "wer": 1.0,
+                "nb_words_no_punct": 9,
+                "wer_no_punct": 1.0,
+                "nb_samples": 2,
+            },
+            {
+                "nb_chars": 49,
+                "cer": 1.1429,
+                "nb_words": 9,
+                "wer": 1.0,
+                "nb_words_no_punct": 9,
+                "wer_no_punct": 1.0,
+                "nb_samples": 2,
+            },
         ),
     ),
 )
@@ -136,11 +137,12 @@ def test_train_and_test(
             tmp_path
             / training_config["training_params"]["output_folder"]
             / "results"
-            / f"predict_training-{split_name}_0.txt"
-        ).open(
-            "r",
-        ) as f:
-            res = f.read().splitlines()
+            / f"predict_training-{split_name}_0.yaml"
+        ).open() as f:
             # Remove the times from the results as they vary
-            res = [result for result in res if "time" not in result]
+            res = {
+                metric: value
+                for metric, value in yaml.safe_load(f).items()
+                if "time" not in metric
+            }
             assert res == expected_res
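
The comprehension in the last hunk drops both "time" and "sample_time",
since the filter matches any key containing the substring "time". A
standalone illustration (sample values invented):

    import yaml

    raw = yaml.safe_load("cer: 1.2791\nnb_samples: 2\ntime: 12\nsample_time: 6.0\n")
    res = {metric: value for metric, value in raw.items() if "time" not in metric}
    assert res == {"cer": 1.2791, "nb_samples": 2}
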
-- 
GitLab