diff --git a/dan/ocr/manager/training.py b/dan/ocr/manager/training.py
index 0c440f49e5f4b55e54a3609802549e8786260905..894e489d75472f02e542479a0b275b19c96d203e 100644
--- a/dan/ocr/manager/training.py
+++ b/dan/ocr/manager/training.py
@@ -224,6 +224,8 @@ class GenericTrainingManager:
             self.best = checkpoint["best"]
             if "scaler_state_dict" in checkpoint:
                 self.scaler.load_state_dict(checkpoint["scaler_state_dict"])
+            if "dropout_scheduler_step" in checkpoint:
+                self.dropout_scheduler.resume(checkpoint["dropout_scheduler_step"])
         # Load model weights from past training
         for model_name in self.models:
             # Transform to DDP/from DDP model
@@ -412,6 +414,7 @@ class GenericTrainingManager:
             "scaler_state_dict": self.scaler.state_dict(),
             "best": self.best,
             "charset": self.dataset.charset,
+            "dropout_scheduler_step": self.dropout_scheduler.step_num,
         }

         for model_name in self.optimizers:
diff --git a/dan/ocr/schedulers.py b/dan/ocr/schedulers.py
index bdabe62980244dfadf11df22147fa239d9fa5d53..89f81ac7f45b1f4f5188d0bd3829707aea80bc7d 100644
--- a/dan/ocr/schedulers.py
+++ b/dan/ocr/schedulers.py
@@ -17,6 +17,9 @@ class DropoutScheduler:
     def step(self, num):
         self.step_num += num

+    def resume(self, step_num):
+        self.step_num = step_num
+
     def init_teta_list(self, models):
         for model_name in models:
             self.init_teta_list_module(models[model_name])
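
What the patch accomplishes, as a minimal self-contained sketch: the dropout scheduler's step counter is written into the checkpoint at save time and restored at load time, so dropout decay continues from where training stopped instead of restarting from step 0. `DropoutScheduler` below only mirrors the methods touched in the diff; the `checkpoint` dict and its placeholder values stand in for `GenericTrainingManager`'s real save/load logic and are illustrative assumptions, not repo code.

class DropoutScheduler:
    def __init__(self):
        self.step_num = 0

    def step(self, num):
        # Advance the scheduler by `num` training steps.
        self.step_num += num

    def resume(self, step_num):
        # Restore the scheduler to a previously saved position.
        self.step_num = step_num


# Saving: the step counter goes into the checkpoint alongside other state.
scheduler = DropoutScheduler()
scheduler.step(500)
checkpoint = {
    "best": 0.42,  # placeholder for the other checkpoint keys
    "dropout_scheduler_step": scheduler.step_num,
}

# Loading: the key is guarded with `in`, so checkpoints written before
# this change still load; they just start the scheduler from step 0.
restored = DropoutScheduler()
if "dropout_scheduler_step" in checkpoint:
    restored.resume(checkpoint["dropout_scheduler_step"])
assert restored.step_num == 500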