From 568e5880824e342b97916de35de293d8733e8f9d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?M=C3=A9lodie=20Boillet?= <boillet@teklia.com>
Date: Fri, 13 Oct 2023 12:07:20 +0200
Subject: [PATCH] Save and restore dropout scheduler step_num in checkpoints

---
 dan/ocr/manager/training.py | 3 +++
 dan/ocr/schedulers.py       | 3 +++
 2 files changed, 6 insertions(+)

diff --git a/dan/ocr/manager/training.py b/dan/ocr/manager/training.py
index 0c440f49..894e489d 100644
--- a/dan/ocr/manager/training.py
+++ b/dan/ocr/manager/training.py
@@ -224,6 +224,8 @@ class GenericTrainingManager:
         self.best = checkpoint["best"]
         if "scaler_state_dict" in checkpoint:
             self.scaler.load_state_dict(checkpoint["scaler_state_dict"])
+        if "dropout_scheduler_step" in checkpoint:
+            self.dropout_scheduler.resume(checkpoint["dropout_scheduler_step"])
         # Load model weights from past training
         for model_name in self.models:
             # Transform to DDP/from DDP model
@@ -412,6 +414,7 @@ class GenericTrainingManager:
             "scaler_state_dict": self.scaler.state_dict(),
             "best": self.best,
             "charset": self.dataset.charset,
+            "dropout_scheduler_step": self.dropout_scheduler.step_num,
         }
 
         for model_name in self.optimizers:
diff --git a/dan/ocr/schedulers.py b/dan/ocr/schedulers.py
index bdabe629..89f81ac7 100644
--- a/dan/ocr/schedulers.py
+++ b/dan/ocr/schedulers.py
@@ -17,6 +17,9 @@ class DropoutScheduler:
     def step(self, num):
         self.step_num += num
 
+    def resume(self, step_num):
+        self.step_num = step_num  # overwrite the counter; step() only increments it
+
     def init_teta_list(self, models):
         for model_name in models:
             self.init_teta_list_module(models[model_name])
-- 
GitLab