From 4c1ccf1047c033c8cf7086adf8b579e50bb17ee4 Mon Sep 17 00:00:00 2001
From: Manon blanco <blanco@teklia.com>
Date: Tue, 18 Jul 2023 08:34:16 +0000
Subject: [PATCH] Generate the correct `parameters.yml` file directly during
 training
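
Training used to write a JSON summary of every hyperparameter to a
`params` file, and the documentation asked users to hand-write a
`parameters.yml` file from it before running prediction. `save_params`
now writes a ready-to-use `parameters.yml` directly to
`{training_params.output_folder}/results`, keeping only the parameters
needed at prediction time. The manual step is therefore dropped from
the documentation, and the `version` key is no longer part of the file.

The generated file follows the format that was previously documented
(the exact `preprocessings` keys depend on the configured
preprocessings):

```yml
parameters:
  max_char_prediction: int
  encoder:
    dropout: float
  decoder:
    enc_dim: int
    l_max: int
    dec_pred_dropout: float
    attention_win: int
    vocab_size: int
    h_max: int
    w_max: int
    dec_num_layers: int
    dec_dim_feedforward: int
    dec_num_heads: int
    dec_att_dropout: float
    dec_res_dropout: float
  preprocessings:
    - type: str
      max_height: int
      max_width: int
      fixed_height: int
      fixed_width: int
```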

---
 dan/manager/training.py              | 71 ++++++++++++++++------------
 docs/get_started/training.md         | 29 +-----------
 tests/data/prediction/parameters.yml |  1 -
 tests/test_training.py               | 40 +++++++++++++++-
 4 files changed, 82 insertions(+), 59 deletions(-)
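
Below, for reference, is the file produced with the test configuration
(values from `tests/test_training.py`; `yaml.dump` sorts keys
alphabetically, so the actual key order differs):

```yml
parameters:
  max_char_prediction: 30
  encoder:
    dropout: 0.5
  decoder:
    enc_dim: 256
    l_max: 15000
    h_max: 500
    w_max: 1000
    dec_num_layers: 8
    dec_num_heads: 4
    dec_res_dropout: 0.1
    dec_pred_dropout: 0.1
    dec_att_dropout: 0.1
    dec_dim_feedforward: 256
    vocab_size: 96
    attention_win: 100
  preprocessings:
    - max_height: 2000
      max_width: 2000
      type: max_resize
```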

diff --git a/dan/manager/training.py b/dan/manager/training.py
index 40327d4e..a3e1af66 100644
--- a/dan/manager/training.py
+++ b/dan/manager/training.py
@@ -1,8 +1,7 @@
 # -*- coding: utf-8 -*-
-import json
 import os
 import random
-from copy import deepcopy
+from enum import Enum
 from time import time
 
 import numpy as np
@@ -454,37 +453,51 @@ class GenericTrainingManager:
 
     def save_params(self):
         """
-        Output text file containing a summary of all hyperparameters chosen for the training
+        Output a YAML file containing the parameters needed to load the trained model for prediction
         """
-
-        def compute_nb_params(module):
-            return sum([np.prod(p.size()) for p in list(module.parameters())])
-
-        def class_to_str_dict(my_dict):
-            for key in my_dict.keys():
-                if callable(my_dict[key]):
-                    my_dict[key] = my_dict[key].__name__
-                elif isinstance(my_dict[key], np.ndarray):
-                    my_dict[key] = my_dict[key].tolist()
-                elif isinstance(my_dict[key], dict):
-                    my_dict[key] = class_to_str_dict(my_dict[key])
-            return my_dict
-
-        path = os.path.join(self.paths["results"], "params")
+        path = os.path.join(self.paths["results"], "parameters.yml")
         if os.path.isfile(path):
             return
-        params = class_to_str_dict(my_dict=deepcopy(self.params))
-        total_params = 0
-        for model_name in self.models.keys():
-            current_params = compute_nb_params(self.models[model_name])
-            params["model_params"]["models"][model_name] = [
-                params["model_params"]["models"][model_name],
-                "{:,}".format(current_params),
-            ]
-            total_params += current_params
-        params["model_params"]["total_params"] = "{:,}".format(total_params)
+
+        params = {  # keep only the parameters needed at prediction time
+            "parameters": {
+                "max_char_prediction": self.params["training_params"][
+                    "max_char_prediction"
+                ],
+                "encoder": {
+                    "dropout": self.params["model_params"]["dropout"],
+                },
+                "decoder": {
+                    key: self.params["model_params"][key]
+                    for key in [
+                        "enc_dim",
+                        "l_max",
+                        "h_max",
+                        "w_max",
+                        "dec_num_layers",
+                        "dec_num_heads",
+                        "dec_res_dropout",
+                        "dec_pred_dropout",
+                        "dec_att_dropout",
+                        "dec_dim_feedforward",
+                        "vocab_size",
+                        "attention_win",
+                    ]
+                },
+                "preprocessings": [
+                    {
+                        key: value.value if isinstance(value, Enum) else value  # dump Enum members (e.g. preprocessing type) as plain values
+                        for key, value in preprocessing.items()
+                    }
+                    for preprocessing in self.params["dataset_params"]["config"].get(
+                        "preprocessings", []
+                    )
+                ],
+            },
+        }
+
         with open(path, "w") as f:
-            json.dump(params, f, indent=4)
+            yaml.dump(params, f)  # NB: yaml.dump sorts keys alphabetically by default
 
     def backward_loss(self, loss, retain_graph=False):
         self.scaler.scale(loss).backward(retain_graph=retain_graph)
diff --git a/docs/get_started/training.md b/docs/get_started/training.md
index b7099281..9f034f0d 100644
--- a/docs/get_started/training.md
+++ b/docs/get_started/training.md
@@ -45,31 +45,4 @@ Once the training is complete, you can apply a trained DAN model on an image.
 
 To do this, you will need to:
 
-1. Create a `parameters.yml` file using the parameters saved during training in the `params` file, located in `{training_params.output_folder}/results`. This file should have the following format:
-```yml
-version: 0.0.1
-parameters:
-  max_char_prediction: int
-  encoder:
-    dropout: float
-  decoder:
-    enc_dim: int
-    l_max: int
-    dec_pred_dropout: float
-    attention_win: int
-    vocab_size: int
-    h_max: int
-    w_max: int
-    dec_num_layers: int
-    dec_dim_feedforward: int
-    dec_num_heads: int
-    dec_att_dropout: float
-    dec_res_dropout: float
-  preprocessings:
-    - type: str
-      max_height: int
-      max_width: int
-      fixed_height: int
-      fixed_width: int
-```
-2. Apply a trained DAN model on an image using the [predict command](../usage/predict.md).
+1. Apply a trained DAN model on an image using the [predict command](../usage/predict.md). The required `parameters.yml` file is now generated automatically during training and saved in `{training_params.output_folder}/results`.
diff --git a/tests/data/prediction/parameters.yml b/tests/data/prediction/parameters.yml
index 32ffff56..db1880b0 100644
--- a/tests/data/prediction/parameters.yml
+++ b/tests/data/prediction/parameters.yml
@@ -1,5 +1,4 @@
 ---
-version: 0.0.1
 parameters:
   max_char_prediction: 200
   encoder:
diff --git a/tests/test_training.py b/tests/test_training.py
index f3030fe2..92a3c6bb 100644
--- a/tests/test_training.py
+++ b/tests/test_training.py
@@ -9,7 +9,7 @@ from tests.conftest import FIXTURES
 
 
 @pytest.mark.parametrize(
-    "expected_best_model_name, expected_last_model_name, training_res, val_res, test_res",
+    "expected_best_model_name, expected_last_model_name, training_res, val_res, test_res, params_res",
     (
         (
             "best_0.pt",
@@ -41,6 +41,33 @@ from tests.conftest import FIXTURES
                 "wer_no_punct": 1.0,
                 "nb_samples": 2,
             },
+            {
+                "parameters": {
+                    "max_char_prediction": 30,
+                    "encoder": {"dropout": 0.5},
+                    "decoder": {
+                        "enc_dim": 256,
+                        "l_max": 15000,
+                        "h_max": 500,
+                        "w_max": 1000,
+                        "dec_num_layers": 8,
+                        "dec_num_heads": 4,
+                        "dec_res_dropout": 0.1,
+                        "dec_pred_dropout": 0.1,
+                        "dec_att_dropout": 0.1,
+                        "dec_dim_feedforward": 256,
+                        "vocab_size": 96,
+                        "attention_win": 100,
+                    },
+                    "preprocessings": [
+                        {
+                            "max_height": 2000,
+                            "max_width": 2000,
+                            "type": "max_resize",
+                        }
+                    ],
+                },
+            },
         ),
     ),
 )
@@ -50,6 +77,7 @@ def test_train_and_test(
     training_res,
     val_res,
     test_res,
+    params_res,
     training_config,
     tmp_path,
 ):
@@ -146,3 +174,13 @@ def test_train_and_test(
                 if "time" not in metric
             }
             assert res == expected_res
+
+    # Check that the parameters file is correct
+    with (
+        tmp_path
+        / training_config["training_params"]["output_folder"]
+        / "results"
+        / "parameters.yml"
+    ).open() as f:
+        res = yaml.safe_load(f)
+        assert res == params_res
-- 
GitLab