Skip to content
Snippets Groups Projects
Unverified Commit 5748aefb authored by Nolan's avatar Nolan Committed by Yoann Schneider
Browse files

Rename teacher_forcing_scheduler to label_noise_scheduler

parent 6a749bed
No related branches found
No related tags found
No related merge requests found
......@@ -1116,26 +1116,26 @@ class Manager(OCRManager):
):
error_rate = self.params["training_params"]["teacher_forcing_error_rate"]
simulated_y_pred, y_len = self.add_label_noise(y, y_len, error_rate)
elif "teacher_forcing_scheduler" in self.params["training_params"]:
elif "label_noise_scheduler" in self.params["training_params"]:
error_rate = (
self.params["training_params"]["teacher_forcing_scheduler"][
self.params["training_params"]["label_noise_scheduler"][
"min_error_rate"
]
+ min(
self.latest_step,
self.params["training_params"]["teacher_forcing_scheduler"][
self.params["training_params"]["label_noise_scheduler"][
"total_num_steps"
],
)
* (
self.params["training_params"]["teacher_forcing_scheduler"][
self.params["training_params"]["label_noise_scheduler"][
"max_error_rate"
]
- self.params["training_params"]["teacher_forcing_scheduler"][
- self.params["training_params"]["label_noise_scheduler"][
"min_error_rate"
]
)
/ self.params["training_params"]["teacher_forcing_scheduler"][
/ self.params["training_params"]["label_noise_scheduler"][
"total_num_steps"
]
)
......
......@@ -223,7 +223,7 @@ def get_config():
"force_cpu": False, # True for debug purposes
"max_char_prediction": 1000, # max number of token prediction
# Keep teacher forcing rate to 20% during whole training
"teacher_forcing_scheduler": {
"label_noise_scheduler": {
"min_error_rate": 0.2,
"max_error_rate": 0.2,
"total_num_steps": 5e4,
......
......@@ -270,9 +270,9 @@ The following configuration can be used by default. It must be defined in `datas
| `training_params.train_metrics` | List of metrics to compute during validation. | `list` | `["cer", "wer", "wer_no_punct"]` |
| `training_params.force_cpu` | Whether to train on CPU (for debugging). | `bool` | `False` |
| `training_params.max_char_prediction` | Maximum number of characters to predict. | `int` | `1000` |
| `training_params.teacher_forcing_scheduler.min_error_rate` | Minimum ratio of teacher forcing. | `float` | `0.2` |
| `training_params.teacher_forcing_scheduler.max_error_rate` | Maximum ratio of teacher forcing. | `float` | `0.2` |
| `training_params.teacher_forcing_scheduler.total_num_steps` | Number of steps before stopping teacher forcing. | `float` | `5e4` |
| `training_params.label_noise_scheduler.min_error_rate` | Minimum ratio of teacher forcing. | `float` | `0.2` |
| `training_params.label_noise_scheduler.max_error_rate` | Maximum ratio of teacher forcing. | `float` | `0.2` |
| `training_params.label_noise_scheduler.total_num_steps` | Number of steps before stopping teacher forcing. | `float` | `5e4` |
## MLFlow logging
......
......@@ -158,7 +158,7 @@ def training_config():
"force_cpu": True, # True for debug purposes
"max_char_prediction": 30, # max number of token prediction
# Keep teacher forcing rate to 20% during whole training
"teacher_forcing_scheduler": {
"label_noise_scheduler": {
"min_error_rate": 0.2,
"max_error_rate": 0.2,
"total_num_steps": 5e4,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment