Always use the same dropout scheduler function

Merged: Yoann Schneider requested to merge always-use-dropout-scheduler into main. All threads resolved.

Files changed: 5 (+5, -11)
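In short, as the diff below shows: DropoutScheduler is now constructed unconditionally from the models alone, so the configurable "function" and "T" entries in the model parameters, the None initialization, and the runtime guard around the scheduler all disappear.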
@@ -32,7 +32,6 @@ class GenericTrainingManager:
         self.type = None
         self.is_master = False
         self.params = params
-        self.dropout_scheduler = None
         self.models = {}
         self.dataset = None
         self.dataset_name = list(self.params["dataset_params"]["datasets"].values())[0]
@@ -196,10 +195,7 @@ class GenericTrainingManager:
             )
         # Handle curriculum dropout
-        if "dropout_scheduler" in self.params["model_params"]:
-            func = self.params["model_params"]["dropout_scheduler"]["function"]
-            T = self.params["model_params"]["dropout_scheduler"]["T"]
-            self.dropout_scheduler = DropoutScheduler(self.models, func, T)
+        self.dropout_scheduler = DropoutScheduler(self.models)
         self.scaler = GradScaler(enabled=self.params["training_params"]["use_amp"])
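For context, the new call site passes only the model dict, which implies DropoutScheduler now owns its schedule internally. The sketch below shows one way that could look; only the call site is visible in this diff, so the exponential_dropout function, the default horizon T, and the module-walking logic are all assumptions, not the repository's actual implementation.

    import math

    import torch.nn as nn

    def exponential_dropout(initial_rate, step, max_step):
        # Assumed fixed schedule: ramp dropout from ~0 toward the configured rate.
        return initial_rate * (1 - math.exp(-10 * step / max_step))

    class DropoutScheduler:
        def __init__(self, models, T=50000):
            self.models = models  # dict of model name -> nn.Module
            self.T = T            # assumed default horizon, no longer read from params
            self.step_num = 0
            # Record each Dropout module's configured rate as the target to ramp to.
            self.initial_rates = {
                (name, i): module.p
                for name, model in models.items()
                for i, module in enumerate(model.modules())
                if isinstance(module, nn.Dropout)
            }

        def step(self, num=1):
            self.step_num += num

        def update_dropout_rate(self):
            for name, model in self.models.items():
                for i, module in enumerate(model.modules()):
                    if isinstance(module, nn.Dropout):
                        target = self.initial_rates[(name, i)]
                        progress = min(self.step_num, self.T)
                        module.p = exponential_dropout(target, progress, self.T)

Reading the target rate from each Dropout module's own p keeps the scheduler free of per-model configuration, which is consistent with the constructor now taking only the models.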
@@ -652,13 +648,11 @@ class GenericTrainingManager:
                 if "lr" in metric_names:
                     self.writer.add_scalar(
                         "lr_{}".format(model_name),
                         self.lr_schedulers[model_name].lr,
                         self.lr_schedulers[model_name].step_num,
                     )
-            # Update dropout scheduler if used
-            if self.dropout_scheduler:
-                self.dropout_scheduler.step(len(batch_metrics["names"]))
-                self.dropout_scheduler.update_dropout_rate()
+            # Update dropout scheduler
+            self.dropout_scheduler.step(len(batch_metrics["names"]))
+            self.dropout_scheduler.update_dropout_rate()
             # Add batch metrics values to epoch metrics values
             self.metric_manager["train"].update_metrics(batch_metrics)
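Because the scheduler is now created unconditionally in __init__, the training loop no longer needs the None guard before stepping it. A hypothetical smoke test of the sketch above (the model layout and batch size are made up):

    models = {"encoder": nn.Sequential(nn.Linear(8, 8), nn.Dropout(0.5))}
    scheduler = DropoutScheduler(models)
    for _ in range(3):
        scheduler.step(16)               # advance by the number of samples in the batch
        scheduler.update_dropout_rate()  # dropout ramps from ~0 toward 0.5 over T steps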