Skip to content
Snippets Groups Projects
Commit a633050e authored by Manon Blanco's avatar Manon Blanco Committed by Mélodie Boillet
Browse files

Always save model with the lowest "cer"

parent cb58891e
No related branches found
No related tags found
1 merge request!210Always save model with the lowest "cer"
......@@ -522,7 +522,6 @@ class GenericTrainingManager:
self.save_params()
# init variables
self.begin_time = time()
focus_metric_name = self.params["training_params"]["focus_metric"]
nb_epochs = self.params["training_params"]["max_nb_epochs"]
metric_names = self.params["training_params"]["train_metrics"]
......@@ -642,25 +641,9 @@ class GenericTrainingManager:
)
if valid_set_name == self.params["training_params"][
"set_name_focus_metric"
] and (
self.best is None
or (
eval_values[focus_metric_name] <= self.best
and self.params["training_params"][
"expected_metric_value"
]
== "low"
)
or (
eval_values[focus_metric_name] >= self.best
and self.params["training_params"][
"expected_metric_value"
]
== "high"
)
):
] and (self.best is None or eval_values["cer"] <= self.best):
self.save_model(epoch=num_epoch, name="best")
self.best = eval_values[focus_metric_name]
self.best = eval_values["cer"]
# Handle curriculum learning update
if self.dataset.train_dataset.curriculum_config:
......
......@@ -181,8 +181,6 @@ def get_config():
"lr_schedulers": None, # Learning rate schedulers
"eval_on_valid": True, # Whether to eval and logs metrics on validation set during training or not
"eval_on_valid_interval": 5, # Interval (in epochs) to evaluate during training
"focus_metric": "cer", # Metrics to focus on to determine best epoch
"expected_metric_value": "low", # ["high", "low"] What is best for the focus metric value
"set_name_focus_metric": "{}-val".format(
dataset_name
), # Which dataset to focus on to select best weights
......
......@@ -159,8 +159,6 @@ For a detailed description of all augmentation transforms, see the [dedicated pa
| `training_params.lr_schedulers` | Learning rate schedulers. | custom class | `None` |
| `training_params.eval_on_valid` | Whether to evaluate and log metrics on the validation set during training. | `bool` | `True` |
| `training_params.eval_on_valid_interval` | Interval (in epochs) to evaluate during training. | `int` | `5` |
| `training_params.focus_metric` | Metrics to focus on to determine best epoch. | `str` | `cer` |
| `training_params.expected_metric_value` | Best value for the focus metric. Should be either `"high"` or `"low"`. | `low` | `cer` |
| `training_params.set_name_focus_metric` | Dataset to focus on to select best weights. | `str` | |
| `training_params.train_metrics` | List of metrics to compute during training. | `list` | `["loss_ce", "cer", "wer", "wer_no_punct"]` |
| `training_params.eval_metrics` | List of metrics to compute during validation. | `list` | `["cer", "wer", "wer_no_punct"]` |
......
......@@ -125,8 +125,6 @@ def training_config():
"lr_schedulers": None, # Learning rate schedulers
"eval_on_valid": True, # Whether to eval and logs metrics on validation set during training or not
"eval_on_valid_interval": 2, # Interval (in epochs) to evaluate during training
"focus_metric": "cer", # Metrics to focus on to determine best epoch
"expected_metric_value": "low", # ["high", "low"] What is best for the focus metric value
"set_name_focus_metric": "training-val", # Which dataset to focus on to select best weights
"train_metrics": [
"loss_ce",
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment