From 378a86c087e23f495c9cc1252a0f377be0da641b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?M=C3=A9lodie=20Boillet?= <boillet@teklia.com> Date: Tue, 30 May 2023 09:14:54 +0000 Subject: [PATCH] Remove growing models --- dan/manager/training.py | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/dan/manager/training.py b/dan/manager/training.py index 1c7f157b..72acce30 100644 --- a/dan/manager/training.py +++ b/dan/manager/training.py @@ -357,7 +357,6 @@ class GenericTrainingManager: Load the optimizer of each model """ for model_name in self.models.keys(): - new_params = dict() if ( checkpoint and "optimizer_named_params_{}".format(model_name) in checkpoint @@ -365,16 +364,6 @@ class GenericTrainingManager: self.optimizers_named_params_by_group[model_name] = checkpoint[ "optimizer_named_params_{}".format(model_name) ] - # for progressively growing models - for name, param in self.models[model_name].named_parameters(): - existing = False - for gr in self.optimizers_named_params_by_group[model_name]: - if name in gr: - gr[name] = param - existing = True - break - if not existing: - new_params.update({name: param}) else: self.optimizers_named_params_by_group[model_name] = [ dict(), @@ -420,13 +409,6 @@ class GenericTrainingManager: checkpoint["lr_scheduler_{}_state_dict".format(model_name)] ) - # for progressively growing models, keeping learning rate - if checkpoint and new_params: - self.optimizers_named_params_by_group[model_name].append(new_params) - self.optimizers[model_name].add_param_group( - {"params": list(new_params.values())} - ) - @staticmethod def set_model_learnable(model, learnable=True): for p in list(model.parameters()): -- GitLab