Skip to content
Snippets Groups Projects
Verified Commit 568e5880 authored by Mélodie Boillet's avatar Mélodie Boillet
Browse files

Save dropout scheduler step_num

parent b0bca1a7
No related branches found
No related tags found
1 merge request!292Save dropout scheduler step_num
...@@ -224,6 +224,8 @@ class GenericTrainingManager: ...@@ -224,6 +224,8 @@ class GenericTrainingManager:
self.best = checkpoint["best"] self.best = checkpoint["best"]
if "scaler_state_dict" in checkpoint: if "scaler_state_dict" in checkpoint:
self.scaler.load_state_dict(checkpoint["scaler_state_dict"]) self.scaler.load_state_dict(checkpoint["scaler_state_dict"])
if "dropout_scheduler_step" in checkpoint:
self.dropout_scheduler.resume(checkpoint["dropout_scheduler_step"])
# Load model weights from past training # Load model weights from past training
for model_name in self.models: for model_name in self.models:
# Transform to DDP/from DDP model # Transform to DDP/from DDP model
...@@ -412,6 +414,7 @@ class GenericTrainingManager: ...@@ -412,6 +414,7 @@ class GenericTrainingManager:
"scaler_state_dict": self.scaler.state_dict(), "scaler_state_dict": self.scaler.state_dict(),
"best": self.best, "best": self.best,
"charset": self.dataset.charset, "charset": self.dataset.charset,
"dropout_scheduler_step": self.dropout_scheduler.step_num,
} }
for model_name in self.optimizers: for model_name in self.optimizers:
......
...@@ -17,6 +17,9 @@ class DropoutScheduler: ...@@ -17,6 +17,9 @@ class DropoutScheduler:
def step(self, num): def step(self, num):
self.step_num += num self.step_num += num
def resume(self, step_num):
self.step_num = step_num
def init_teta_list(self, models): def init_teta_list(self, models):
for model_name in models: for model_name in models:
self.init_teta_list_module(models[model_name]) self.init_teta_list_module(models[model_name])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment