From 378a86c087e23f495c9cc1252a0f377be0da641b Mon Sep 17 00:00:00 2001
From: Mélodie Boillet <boillet@teklia.com>
Date: Tue, 30 May 2023 09:14:54 +0000
Subject: [PATCH] Remove support for progressively growing models

---
 dan/manager/training.py | 18 ------------------
 1 file changed, 18 deletions(-)

diff --git a/dan/manager/training.py b/dan/manager/training.py
index 1c7f157b..72acce30 100644
--- a/dan/manager/training.py
+++ b/dan/manager/training.py
@@ -357,7 +357,6 @@ class GenericTrainingManager:
         Load the optimizer of each model
         """
         for model_name in self.models.keys():
-            new_params = dict()
             if (
                 checkpoint
                 and "optimizer_named_params_{}".format(model_name) in checkpoint
@@ -365,16 +364,6 @@ class GenericTrainingManager:
                 self.optimizers_named_params_by_group[model_name] = checkpoint[
                     "optimizer_named_params_{}".format(model_name)
                 ]
-                # for progressively growing models
-                for name, param in self.models[model_name].named_parameters():
-                    existing = False
-                    for gr in self.optimizers_named_params_by_group[model_name]:
-                        if name in gr:
-                            gr[name] = param
-                            existing = True
-                            break
-                    if not existing:
-                        new_params.update({name: param})
             else:
                 self.optimizers_named_params_by_group[model_name] = [
                     dict(),
@@ -420,13 +409,6 @@ class GenericTrainingManager:
                         checkpoint["lr_scheduler_{}_state_dict".format(model_name)]
                     )
 
-            # for progressively growing models, keeping learning rate
-            if checkpoint and new_params:
-                self.optimizers_named_params_by_group[model_name].append(new_params)
-                self.optimizers[model_name].add_param_group(
-                    {"params": list(new_params.values())}
-                )
-
     @staticmethod
     def set_model_learnable(model, learnable=True):
         for p in list(model.parameters()):
-- 
GitLab
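
Note for reviewers: the deleted lines handled checkpoint restore for progressively growing models, i.e. models that gained parameters after the optimizer state was saved. Existing named parameters were remapped into their saved optimizer groups, and any parameters not found there were collected and registered as an extra parameter group so they would still be optimized at the current learning rate. Below is a minimal sketch of that pattern in plain PyTorch, assuming an Adam optimizer and a layer added after optimizer creation; the names (model, new_layer) are illustrative and not taken from this repository.

    import torch
    from torch import nn

    # Hypothetical setup: an optimizer built before the model "grew".
    model = nn.Linear(8, 8)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # Simulate growth: a layer created after the optimizer.
    new_layer = nn.Linear(8, 4)

    # Find parameters the optimizer does not track yet.
    known = {id(p) for group in optimizer.param_groups for p in group["params"]}
    new_params = [p for p in new_layer.parameters() if id(p) not in known]

    if new_params:
        # Register them as an extra parameter group; existing groups and
        # their learning rates are left untouched.
        optimizer.add_param_group({"params": new_params})

With growing-model support dropped, load_optimizers no longer needs the new_params bookkeeping, which is why both removed blocks go away together.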