From a633050ee293d0ade5ba78e859b94b561435493a Mon Sep 17 00:00:00 2001
From: manonBlanco <blanco@teklia.com>
Date: Mon, 17 Jul 2023 12:19:51 +0200
Subject: [PATCH] Always save model with the lowest "cer"

---
 dan/manager/training.py        | 21 ++-------------------
 dan/ocr/document/train.py      |  2 --
 docs/usage/train/parameters.md |  2 --
 tests/conftest.py              |  2 --
 4 files changed, 2 insertions(+), 25 deletions(-)

diff --git a/dan/manager/training.py b/dan/manager/training.py
index 1eba8670..40327d4e 100644
--- a/dan/manager/training.py
+++ b/dan/manager/training.py
@@ -522,7 +522,6 @@ class GenericTrainingManager:
             self.save_params()
         # init variables
         self.begin_time = time()
-        focus_metric_name = self.params["training_params"]["focus_metric"]
         nb_epochs = self.params["training_params"]["max_nb_epochs"]
         metric_names = self.params["training_params"]["train_metrics"]
 
@@ -642,25 +641,9 @@ class GenericTrainingManager:
                             )
                         if valid_set_name == self.params["training_params"][
                             "set_name_focus_metric"
-                        ] and (
-                            self.best is None
-                            or (
-                                eval_values[focus_metric_name] <= self.best
-                                and self.params["training_params"][
-                                    "expected_metric_value"
-                                ]
-                                == "low"
-                            )
-                            or (
-                                eval_values[focus_metric_name] >= self.best
-                                and self.params["training_params"][
-                                    "expected_metric_value"
-                                ]
-                                == "high"
-                            )
-                        ):
+                        ] and (self.best is None or eval_values["cer"] <= self.best):
                             self.save_model(epoch=num_epoch, name="best")
-                            self.best = eval_values[focus_metric_name]
+                            self.best = eval_values["cer"]
 
             # Handle curriculum learning update
             if self.dataset.train_dataset.curriculum_config:
diff --git a/dan/ocr/document/train.py b/dan/ocr/document/train.py
index d4c8330a..fb0ccf39 100644
--- a/dan/ocr/document/train.py
+++ b/dan/ocr/document/train.py
@@ -181,8 +181,6 @@ def get_config():
             "lr_schedulers": None,  # Learning rate schedulers
             "eval_on_valid": True,  # Whether to eval and logs metrics on validation set during training or not
             "eval_on_valid_interval": 5,  # Interval (in epochs) to evaluate during training
-            "focus_metric": "cer",  # Metrics to focus on to determine best epoch
-            "expected_metric_value": "low",  # ["high", "low"] What is best for the focus metric value
             "set_name_focus_metric": "{}-val".format(
                 dataset_name
             ),  # Which dataset to focus on to select best weights
diff --git a/docs/usage/train/parameters.md b/docs/usage/train/parameters.md
index bf0632a9..f999209d 100644
--- a/docs/usage/train/parameters.md
+++ b/docs/usage/train/parameters.md
@@ -159,8 +159,6 @@ For a detailed description of all augmentation transforms, see the [dedicated pa
 | `training_params.lr_schedulers`                         | Learning rate schedulers.                                                   | custom class | `None`                                      |
 | `training_params.eval_on_valid`                         | Whether to evaluate and log metrics on the validation set during training.  | `bool`       | `True`                                      |
 | `training_params.eval_on_valid_interval`                | Interval (in epochs) to evaluate during training.                           | `int`        | `5`                                         |
-| `training_params.focus_metric`                          | Metrics to focus on to determine best epoch.                                | `str`        | `cer`                                       |
-| `training_params.expected_metric_value`                 | Best value for the focus metric. Should be either `"high"` or `"low"`.      | `low`        | `cer`                                       |
 | `training_params.set_name_focus_metric`                 | Dataset to focus on to select best weights.                                 | `str`        |                                             |
 | `training_params.train_metrics`                         | List of metrics to compute during training.                                 | `list`       | `["loss_ce", "cer", "wer", "wer_no_punct"]` |
 | `training_params.eval_metrics`                          | List of metrics to compute during validation.                               | `list`       | `["cer", "wer", "wer_no_punct"]`            |
diff --git a/tests/conftest.py b/tests/conftest.py
index d85d11b4..ffc0dccb 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -125,8 +125,6 @@ def training_config():
             "lr_schedulers": None,  # Learning rate schedulers
             "eval_on_valid": True,  # Whether to eval and logs metrics on validation set during training or not
             "eval_on_valid_interval": 2,  # Interval (in epochs) to evaluate during training
-            "focus_metric": "cer",  # Metrics to focus on to determine best epoch
-            "expected_metric_value": "low",  # ["high", "low"] What is best for the focus metric value
             "set_name_focus_metric": "training-val",  # Which dataset to focus on to select best weights
             "train_metrics": [
                 "loss_ce",
-- 
GitLab