Skip to content
Snippets Groups Projects

Remove growing models

Merged Mélodie Boillet requested to merge remove-growing-models into main
1 file
+ 0
18
Compare changes
  • Side-by-side
  • Inline
+ 0
18
@@ -357,7 +357,6 @@ class GenericTrainingManager:
Load the optimizer of each model
"""
for model_name in self.models.keys():
new_params = dict()
if (
checkpoint
and "optimizer_named_params_{}".format(model_name) in checkpoint
@@ -365,16 +364,6 @@ class GenericTrainingManager:
self.optimizers_named_params_by_group[model_name] = checkpoint[
"optimizer_named_params_{}".format(model_name)
]
# for progressively growing models
for name, param in self.models[model_name].named_parameters():
existing = False
for gr in self.optimizers_named_params_by_group[model_name]:
if name in gr:
gr[name] = param
existing = True
break
if not existing:
new_params.update({name: param})
else:
self.optimizers_named_params_by_group[model_name] = [
dict(),
@@ -420,13 +409,6 @@ class GenericTrainingManager:
checkpoint["lr_scheduler_{}_state_dict".format(model_name)]
)
# for progressively growing models, keeping learning rate
if checkpoint and new_params:
self.optimizers_named_params_by_group[model_name].append(new_params)
self.optimizers[model_name].add_param_group(
{"params": list(new_params.values())}
)
@staticmethod
def set_model_learnable(model, learnable=True):
for p in list(model.parameters()):
Loading