Skip to content
Snippets Groups Projects
Verified Commit 95b168cf authored by Mélodie Boillet's avatar Mélodie Boillet
Browse files

Apply a633050e

parent dc482e24
No related branches found
No related tags found
No related merge requests found
...@@ -522,7 +522,6 @@ class GenericTrainingManager: ...@@ -522,7 +522,6 @@ class GenericTrainingManager:
self.save_params() self.save_params()
# init variables # init variables
self.begin_time = time() self.begin_time = time()
focus_metric_name = self.params["training_params"]["focus_metric"]
nb_epochs = self.params["training_params"]["max_nb_epochs"] nb_epochs = self.params["training_params"]["max_nb_epochs"]
metric_names = self.params["training_params"]["train_metrics"] metric_names = self.params["training_params"]["train_metrics"]
...@@ -642,25 +641,9 @@ class GenericTrainingManager: ...@@ -642,25 +641,9 @@ class GenericTrainingManager:
) )
if valid_set_name == self.params["training_params"][ if valid_set_name == self.params["training_params"][
"set_name_focus_metric" "set_name_focus_metric"
] and ( ] and (self.best is None or eval_values["cer"] <= self.best):
self.best is None
or (
eval_values[focus_metric_name] <= self.best
and self.params["training_params"][
"expected_metric_value"
]
== "low"
)
or (
eval_values[focus_metric_name] >= self.best
and self.params["training_params"][
"expected_metric_value"
]
== "high"
)
):
self.save_model(epoch=num_epoch, name="best") self.save_model(epoch=num_epoch, name="best")
self.best = eval_values[focus_metric_name] self.best = eval_values["cer"]
# Handle curriculum learning update # Handle curriculum learning update
if self.dataset.train_dataset.curriculum_config: if self.dataset.train_dataset.curriculum_config:
......
...@@ -181,8 +181,6 @@ def get_config(): ...@@ -181,8 +181,6 @@ def get_config():
"lr_schedulers": None, # Learning rate schedulers "lr_schedulers": None, # Learning rate schedulers
"eval_on_valid": True, # Whether to eval and logs metrics on validation set during training or not "eval_on_valid": True, # Whether to eval and logs metrics on validation set during training or not
"eval_on_valid_interval": 5, # Interval (in epochs) to evaluate during training "eval_on_valid_interval": 5, # Interval (in epochs) to evaluate during training
"focus_metric": "cer", # Metrics to focus on to determine best epoch
"expected_metric_value": "low", # ["high", "low"] What is best for the focus metric value
"set_name_focus_metric": "{}-val".format( "set_name_focus_metric": "{}-val".format(
dataset_name dataset_name
), # Which dataset to focus on to select best weights ), # Which dataset to focus on to select best weights
......
...@@ -168,8 +168,6 @@ For a detailed description of all augmentation transforms, see the [dedicated pa ...@@ -168,8 +168,6 @@ For a detailed description of all augmentation transforms, see the [dedicated pa
| `training_params.lr_schedulers` | Learning rate schedulers. | custom class | `None` | | `training_params.lr_schedulers` | Learning rate schedulers. | custom class | `None` |
| `training_params.eval_on_valid` | Whether to evaluate and log metrics on the validation set during training. | `bool` | `True` | | `training_params.eval_on_valid` | Whether to evaluate and log metrics on the validation set during training. | `bool` | `True` |
| `training_params.eval_on_valid_interval` | Interval (in epochs) to evaluate during training. | `int` | `5` | | `training_params.eval_on_valid_interval` | Interval (in epochs) to evaluate during training. | `int` | `5` |
| `training_params.focus_metric` | Metrics to focus on to determine best epoch. | `str` | `cer` |
| `training_params.expected_metric_value` | Best value for the focus metric. Should be either `"high"` or `"low"`. | `low` | `cer` |
| `training_params.set_name_focus_metric` | Dataset to focus on to select best weights. | `str` | | | `training_params.set_name_focus_metric` | Dataset to focus on to select best weights. | `str` | |
| `training_params.train_metrics` | List of metrics to compute during training. | `list` | `["loss_ce", "cer", "wer", "wer_no_punct"]` | | `training_params.train_metrics` | List of metrics to compute during training. | `list` | `["loss_ce", "cer", "wer", "wer_no_punct"]` |
| `training_params.eval_metrics` | List of metrics to compute during validation. | `list` | `["cer", "wer", "wer_no_punct"]` | | `training_params.eval_metrics` | List of metrics to compute during validation. | `list` | `["cer", "wer", "wer_no_punct"]` |
......
...@@ -125,8 +125,6 @@ def training_config(): ...@@ -125,8 +125,6 @@ def training_config():
"lr_schedulers": None, # Learning rate schedulers "lr_schedulers": None, # Learning rate schedulers
"eval_on_valid": True, # Whether to eval and logs metrics on validation set during training or not "eval_on_valid": True, # Whether to eval and logs metrics on validation set during training or not
"eval_on_valid_interval": 2, # Interval (in epochs) to evaluate during training "eval_on_valid_interval": 2, # Interval (in epochs) to evaluate during training
"focus_metric": "cer", # Metrics to focus on to determine best epoch
"expected_metric_value": "low", # ["high", "low"] What is best for the focus metric value
"set_name_focus_metric": "training-val", # Which dataset to focus on to select best weights "set_name_focus_metric": "training-val", # Which dataset to focus on to select best weights
"train_metrics": [ "train_metrics": [
"loss_ce", "loss_ce",
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment