diff --git a/dan/manager/training.py b/dan/manager/training.py
index 1eba86706aafc091c556eee92901271b548e7598..40327d4e4d833b91edd72142c5939333ccbe38db 100644
--- a/dan/manager/training.py
+++ b/dan/manager/training.py
@@ -522,7 +522,6 @@ class GenericTrainingManager:
         self.save_params()
         # init variables
         self.begin_time = time()
-        focus_metric_name = self.params["training_params"]["focus_metric"]
         nb_epochs = self.params["training_params"]["max_nb_epochs"]
         metric_names = self.params["training_params"]["train_metrics"]
 
@@ -642,25 +641,9 @@ class GenericTrainingManager:
                     )
                     if valid_set_name == self.params["training_params"][
                         "set_name_focus_metric"
-                    ] and (
-                        self.best is None
-                        or (
-                            eval_values[focus_metric_name] <= self.best
-                            and self.params["training_params"][
-                                "expected_metric_value"
-                            ]
-                            == "low"
-                        )
-                        or (
-                            eval_values[focus_metric_name] >= self.best
-                            and self.params["training_params"][
-                                "expected_metric_value"
-                            ]
-                            == "high"
-                        )
-                    ):
+                    ] and (self.best is None or eval_values["cer"] <= self.best):
                         self.save_model(epoch=num_epoch, name="best")
-                        self.best = eval_values[focus_metric_name]
+                        self.best = eval_values["cer"]
 
             # Handle curriculum learning update
             if self.dataset.train_dataset.curriculum_config:
diff --git a/dan/ocr/document/train.py b/dan/ocr/document/train.py
index d4c8330acb0c0437f1f9dbfd7a8a4b902b8901f1..fb0ccf39aeb0d2e71394de56d22bc6b5e00a44f4 100644
--- a/dan/ocr/document/train.py
+++ b/dan/ocr/document/train.py
@@ -181,8 +181,6 @@ def get_config():
             "lr_schedulers": None,  # Learning rate schedulers
             "eval_on_valid": True,  # Whether to eval and logs metrics on validation set during training or not
             "eval_on_valid_interval": 5,  # Interval (in epochs) to evaluate during training
-            "focus_metric": "cer",  # Metrics to focus on to determine best epoch
-            "expected_metric_value": "low",  # ["high", "low"] What is best for the focus metric value
             "set_name_focus_metric": "{}-val".format(
                 dataset_name
             ),  # Which dataset to focus on to select best weights
diff --git a/docs/usage/train/parameters.md b/docs/usage/train/parameters.md
index bf0632a96c1d2a510eda0dddc3e6bf9aa0cdaa1b..f999209d232bfeff0af32191d70b74f8852b829e 100644
--- a/docs/usage/train/parameters.md
+++ b/docs/usage/train/parameters.md
@@ -159,8 +159,6 @@ For a detailed description of all augmentation transforms, see the [dedicated pa
 | `training_params.lr_schedulers` | Learning rate schedulers. | custom class | `None` |
 | `training_params.eval_on_valid` | Whether to evaluate and log metrics on the validation set during training. | `bool` | `True` |
 | `training_params.eval_on_valid_interval` | Interval (in epochs) to evaluate during training. | `int` | `5` |
-| `training_params.focus_metric` | Metrics to focus on to determine best epoch. | `str` | `cer` |
-| `training_params.expected_metric_value` | Best value for the focus metric. Should be either `"high"` or `"low"`. | `low` | `cer` |
 | `training_params.set_name_focus_metric` | Dataset to focus on to select best weights. | `str` | |
 | `training_params.train_metrics` | List of metrics to compute during training. | `list` | `["loss_ce", "cer", "wer", "wer_no_punct"]` |
 | `training_params.eval_metrics` | List of metrics to compute during validation. | `list` | `["cer", "wer", "wer_no_punct"]` |
diff --git a/tests/conftest.py b/tests/conftest.py
index d85d11b41ddf81049a5c02b08e3600b87237167f..ffc0dccb232a330f6f27e37d26683e91f051942f 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -125,8 +125,6 @@ def training_config():
         "lr_schedulers": None,  # Learning rate schedulers
         "eval_on_valid": True,  # Whether to eval and logs metrics on validation set during training or not
         "eval_on_valid_interval": 2,  # Interval (in epochs) to evaluate during training
-        "focus_metric": "cer",  # Metrics to focus on to determine best epoch
-        "expected_metric_value": "low",  # ["high", "low"] What is best for the focus metric value
         "set_name_focus_metric": "training-val",  # Which dataset to focus on to select best weights
         "train_metrics": [
             "loss_ce",