
Always use the same dropout scheduler function

Merged · Yoann Schneider requested to merge always-use-dropout-scheduler into main
All threads resolved
Files changed: 2 · +4 −16
@@ -32,7 +32,6 @@ class GenericTrainingManager:
         self.type = None
         self.is_master = False
         self.params = params
-        self.dropout_scheduler = None
         self.models = {}
         self.dataset = None
         self.dataset_name = list(self.params["dataset_params"]["datasets"].values())[0]
@@ -597,12 +596,6 @@ class GenericTrainingManager:
         metric_names = self.params["training_params"]["train_metrics"]
         display_values = None
-        # init curriculum learning
-        if (
-            "curriculum_learning" in self.params["training_params"].keys()
-            and self.params["training_params"]["curriculum_learning"]
-        ):
-            self.init_curriculum()
 
         # perform epochs
         for num_epoch in range(self.latest_epoch + 1, nb_epochs):
             # set models trainable
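Side note: the removed block only ran when the configuration opted in to curriculum learning. A minimal, self-contained illustration of the flag it tested; the surrounding keys and values here are hypothetical, not taken from this repository:

# Illustrative config; only "curriculum_learning" mirrors the removed check.
params = {
    "training_params": {
        "curriculum_learning": True,
        "train_metrics": ["loss", "lr"],  # hypothetical values
    }
}

# dict.get() expresses the removed membership-and-truthiness test in one step.
enabled = bool(params["training_params"].get("curriculum_learning"))
print(enabled)  # True -> init_curriculum() would have been called
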
@@ -655,13 +648,11 @@ class GenericTrainingManager:
                 if "lr" in metric_names:
                     self.writer.add_scalar(
                         "lr_{}".format(model_name),
                         self.lr_schedulers[model_name].lr,
                         self.lr_schedulers[model_name].step_num,
                     )
-                # Update dropout scheduler if used
-                if self.dropout_scheduler:
-                    self.dropout_scheduler.step(len(batch_metrics["names"]))
-                    self.dropout_scheduler.update_dropout_rate()
-
+                # Update dropout scheduler
+                self.dropout_scheduler.step(len(batch_metrics["names"]))
+                self.dropout_scheduler.update_dropout_rate()
                 # Add batch metrics values to epoch metrics values
                 self.metric_manager["train"].update_metrics(batch_metrics)
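The now-unconditional calls assume self.dropout_scheduler always exists and exposes step() and update_dropout_rate(). A minimal runnable sketch of a scheduler with that interface; the linear decay, the T parameter, and the constructor signature are assumptions for illustration, not this repository's implementation:

import torch.nn as nn


def linear_decay(initial_rate: float, step: int, total_steps: int) -> float:
    # Decay the dropout probability linearly from its initial value to 0.
    return initial_rate * max(0.0, 1.0 - step / total_steps)


class DropoutScheduler:
    """Sketch: drives the rate of every nn.Dropout found in the given models."""

    def __init__(self, models: dict, function=linear_decay, T: int = 100_000):
        # Remember each Dropout module alongside its initial rate.
        self.entries = [
            (module, module.p)
            for model in models.values()
            for module in model.modules()
            if isinstance(module, nn.Dropout)
        ]
        self.function = function
        self.T = T
        self.step_num = 0

    def step(self, num: int = 1) -> None:
        # Advance by the number of samples just processed.
        self.step_num += num

    def update_dropout_rate(self) -> None:
        # Write the scheduled rate back into every registered module.
        for module, initial_rate in self.entries:
            module.p = self.function(initial_rate, self.step_num, self.T)


# Usage mirroring the loop above: always step, no None check needed.
models = {"encoder": nn.Sequential(nn.Linear(8, 8), nn.Dropout(0.5))}
scheduler = DropoutScheduler(models, T=1_000)
scheduler.step(32)  # e.g. len(batch_metrics["names"]) for a 32-sample batch
scheduler.update_dropout_rate()
print(models["encoder"][1].p)  # 0.5 * (1 - 32/1000) = 0.484

Because the trainer now guarantees the scheduler is constructed, call sites stay branch-free; a model with no nn.Dropout modules simply makes update_dropout_rate() a no-op.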
@@ -922,9 +913,6 @@ class GenericTrainingManager:
     def evaluate_batch(self, batch_data, metric_names):
         raise NotImplementedError
 
-    def init_curriculum(self):
-        raise NotImplementedError
-
     def load_save_info(self, info_dict):
         """
         Load curriculum info from saved model info