From 568e5880824e342b97916de35de293d8733e8f9d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?M=C3=A9lodie=20Boillet?= <boillet@teklia.com> Date: Fri, 13 Oct 2023 12:07:20 +0200 Subject: [PATCH] Save dropout scheduler step_num --- dan/ocr/manager/training.py | 3 +++ dan/ocr/schedulers.py | 3 +++ 2 files changed, 6 insertions(+) diff --git a/dan/ocr/manager/training.py b/dan/ocr/manager/training.py index 0c440f49..894e489d 100644 --- a/dan/ocr/manager/training.py +++ b/dan/ocr/manager/training.py @@ -224,6 +224,8 @@ class GenericTrainingManager: self.best = checkpoint["best"] if "scaler_state_dict" in checkpoint: self.scaler.load_state_dict(checkpoint["scaler_state_dict"]) + if "dropout_scheduler_step" in checkpoint: + self.dropout_scheduler.resume(checkpoint["dropout_scheduler_step"]) # Load model weights from past training for model_name in self.models: # Transform to DDP/from DDP model @@ -412,6 +414,7 @@ class GenericTrainingManager: "scaler_state_dict": self.scaler.state_dict(), "best": self.best, "charset": self.dataset.charset, + "dropout_scheduler_step": self.dropout_scheduler.step_num, } for model_name in self.optimizers: diff --git a/dan/ocr/schedulers.py b/dan/ocr/schedulers.py index bdabe629..89f81ac7 100644 --- a/dan/ocr/schedulers.py +++ b/dan/ocr/schedulers.py @@ -17,6 +17,9 @@ class DropoutScheduler: def step(self, num): self.step_num += num + def resume(self, step_num): + self.step_num = step_num + def init_teta_list(self, models): for model_name in models: self.init_teta_list_module(models[model_name]) -- GitLab