diff --git a/dan/manager/training.py b/dan/manager/training.py
index c24255be1b0b3fc407bb5d0b69c4cdbca3027ae0..6a72a2451df0197beefdfcd46e26e5cb305f6627 100644
--- a/dan/manager/training.py
+++ b/dan/manager/training.py
@@ -1116,26 +1116,26 @@ class Manager(OCRManager):
         ):
             error_rate = self.params["training_params"]["teacher_forcing_error_rate"]
             simulated_y_pred, y_len = self.add_label_noise(y, y_len, error_rate)
-        elif "teacher_forcing_scheduler" in self.params["training_params"]:
+        elif "label_noise_scheduler" in self.params["training_params"]:
             error_rate = (
-                self.params["training_params"]["teacher_forcing_scheduler"][
+                self.params["training_params"]["label_noise_scheduler"][
                     "min_error_rate"
                 ]
                 + min(
                     self.latest_step,
-                    self.params["training_params"]["teacher_forcing_scheduler"][
+                    self.params["training_params"]["label_noise_scheduler"][
                         "total_num_steps"
                     ],
                 )
                 * (
-                    self.params["training_params"]["teacher_forcing_scheduler"][
+                    self.params["training_params"]["label_noise_scheduler"][
                         "max_error_rate"
                     ]
-                    - self.params["training_params"]["teacher_forcing_scheduler"][
+                    - self.params["training_params"]["label_noise_scheduler"][
                         "min_error_rate"
                     ]
                 )
-                / self.params["training_params"]["teacher_forcing_scheduler"][
+                / self.params["training_params"]["label_noise_scheduler"][
                     "total_num_steps"
                 ]
             )
diff --git a/dan/ocr/document/train.py b/dan/ocr/document/train.py
index 24216c6485e3bd4090d0c7c1f5bc07edaf7752c3..1d6cfb9a48603288f2e3dea8eab9e9f5ce7de0de 100644
--- a/dan/ocr/document/train.py
+++ b/dan/ocr/document/train.py
@@ -223,7 +223,7 @@ def get_config():
         "force_cpu": False,  # True for debug purposes
         "max_char_prediction": 1000,  # max number of token prediction
         # Keep teacher forcing rate to 20% during whole training
-        "teacher_forcing_scheduler": {
+        "label_noise_scheduler": {
             "min_error_rate": 0.2,
             "max_error_rate": 0.2,
             "total_num_steps": 5e4,
diff --git a/docs/usage/train/parameters.md b/docs/usage/train/parameters.md
index 6a347c43be4d7d8ba13ed343deee6e79e5451cb6..dacfd9dd07bc1f97666cf79b5d4df2e5340cc71f 100644
--- a/docs/usage/train/parameters.md
+++ b/docs/usage/train/parameters.md
@@ -270,9 +270,9 @@ The following configuration can be used by default. It must be defined in `datas
 | `training_params.train_metrics` | List of metrics to compute during validation. | `list` | `["cer", "wer", "wer_no_punct"]` |
 | `training_params.force_cpu` | Whether to train on CPU (for debugging). | `bool` | `False` |
 | `training_params.max_char_prediction` | Maximum number of characters to predict. | `int` | `1000` |
-| `training_params.teacher_forcing_scheduler.min_error_rate` | Minimum ratio of teacher forcing. | `float` | `0.2` |
-| `training_params.teacher_forcing_scheduler.max_error_rate` | Maximum ratio of teacher forcing. | `float` | `0.2` |
-| `training_params.teacher_forcing_scheduler.total_num_steps` | Number of steps before stopping teacher forcing. | `float` | `5e4` |
+| `training_params.label_noise_scheduler.min_error_rate` | Minimum ratio of teacher forcing. | `float` | `0.2` |
+| `training_params.label_noise_scheduler.max_error_rate` | Maximum ratio of teacher forcing. | `float` | `0.2` |
+| `training_params.label_noise_scheduler.total_num_steps` | Number of steps before stopping teacher forcing. | `float` | `5e4` |
 
 ## MLFlow logging
 
diff --git a/tests/conftest.py b/tests/conftest.py
index d1bb084c86ba85ea505e434da1896422e586309e..777bdda54579cf43ca1e91c7dd85200781a96d64 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -158,7 +158,7 @@ def training_config():
         "force_cpu": True,  # True for debug purposes
         "max_char_prediction": 30,  # max number of token prediction
         # Keep teacher forcing rate to 20% during whole training
-        "teacher_forcing_scheduler": {
+        "label_noise_scheduler": {
            "min_error_rate": 0.2,
            "max_error_rate": 0.2,
            "total_num_steps": 5e4,