Skip to content
Snippets Groups Projects
Verified Commit 95b168cf authored by Mélodie Boillet's avatar Mélodie Boillet
Browse files

Apply a633050e

parent dc482e24
No related branches found
No related tags found
No related merge requests found
......@@ -522,7 +522,6 @@ class GenericTrainingManager:
self.save_params()
# init variables
self.begin_time = time()
focus_metric_name = self.params["training_params"]["focus_metric"]
nb_epochs = self.params["training_params"]["max_nb_epochs"]
metric_names = self.params["training_params"]["train_metrics"]
......@@ -642,25 +641,9 @@ class GenericTrainingManager:
)
if valid_set_name == self.params["training_params"][
"set_name_focus_metric"
] and (
self.best is None
or (
eval_values[focus_metric_name] <= self.best
and self.params["training_params"][
"expected_metric_value"
]
== "low"
)
or (
eval_values[focus_metric_name] >= self.best
and self.params["training_params"][
"expected_metric_value"
]
== "high"
)
):
] and (self.best is None or eval_values["cer"] <= self.best):
self.save_model(epoch=num_epoch, name="best")
self.best = eval_values[focus_metric_name]
self.best = eval_values["cer"]
# Handle curriculum learning update
if self.dataset.train_dataset.curriculum_config:
......
......@@ -181,8 +181,6 @@ def get_config():
"lr_schedulers": None, # Learning rate schedulers
"eval_on_valid": True, # Whether to eval and logs metrics on validation set during training or not
"eval_on_valid_interval": 5, # Interval (in epochs) to evaluate during training
"focus_metric": "cer", # Metrics to focus on to determine best epoch
"expected_metric_value": "low", # ["high", "low"] What is best for the focus metric value
"set_name_focus_metric": "{}-val".format(
dataset_name
), # Which dataset to focus on to select best weights
......
......@@ -168,8 +168,6 @@ For a detailed description of all augmentation transforms, see the [dedicated pa
| `training_params.lr_schedulers` | Learning rate schedulers. | custom class | `None` |
| `training_params.eval_on_valid` | Whether to evaluate and log metrics on the validation set during training. | `bool` | `True` |
| `training_params.eval_on_valid_interval` | Interval (in epochs) to evaluate during training. | `int` | `5` |
| `training_params.focus_metric` | Metrics to focus on to determine best epoch. | `str` | `cer` |
| `training_params.expected_metric_value` | Best value for the focus metric. Should be either `"high"` or `"low"`. | `low` | `cer` |
| `training_params.set_name_focus_metric` | Dataset to focus on to select best weights. | `str` | |
| `training_params.train_metrics` | List of metrics to compute during training. | `list` | `["loss_ce", "cer", "wer", "wer_no_punct"]` |
| `training_params.eval_metrics` | List of metrics to compute during validation. | `list` | `["cer", "wer", "wer_no_punct"]` |
......
......@@ -125,8 +125,6 @@ def training_config():
"lr_schedulers": None, # Learning rate schedulers
"eval_on_valid": True, # Whether to eval and logs metrics on validation set during training or not
"eval_on_valid_interval": 2, # Interval (in epochs) to evaluate during training
"focus_metric": "cer", # Metrics to focus on to determine best epoch
"expected_metric_value": "low", # ["high", "low"] What is best for the focus metric value
"set_name_focus_metric": "training-val", # Which dataset to focus on to select best weights
"train_metrics": [
"loss_ce",
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment