Skip to content
Snippets Groups Projects
Commit 35982ffb authored by Mélodie Boillet's avatar Mélodie Boillet
Browse files

Remove code for progressively growing models

parent 00eee2e3
No related branches found
No related tags found
No related merge requests found
...@@ -357,7 +357,6 @@ class GenericTrainingManager: ...@@ -357,7 +357,6 @@ class GenericTrainingManager:
Load the optimizer of each model Load the optimizer of each model
""" """
for model_name in self.models.keys(): for model_name in self.models.keys():
new_params = dict()
if ( if (
checkpoint checkpoint
and "optimizer_named_params_{}".format(model_name) in checkpoint and "optimizer_named_params_{}".format(model_name) in checkpoint
...@@ -365,16 +364,6 @@ class GenericTrainingManager: ...@@ -365,16 +364,6 @@ class GenericTrainingManager:
self.optimizers_named_params_by_group[model_name] = checkpoint[ self.optimizers_named_params_by_group[model_name] = checkpoint[
"optimizer_named_params_{}".format(model_name) "optimizer_named_params_{}".format(model_name)
] ]
# for progressively growing models
for name, param in self.models[model_name].named_parameters():
existing = False
for gr in self.optimizers_named_params_by_group[model_name]:
if name in gr:
gr[name] = param
existing = True
break
if not existing:
new_params.update({name: param})
else: else:
self.optimizers_named_params_by_group[model_name] = [ self.optimizers_named_params_by_group[model_name] = [
dict(), dict(),
...@@ -420,13 +409,6 @@ class GenericTrainingManager: ...@@ -420,13 +409,6 @@ class GenericTrainingManager:
checkpoint["lr_scheduler_{}_state_dict".format(model_name)] checkpoint["lr_scheduler_{}_state_dict".format(model_name)]
) )
# for progressively growing models, keeping learning rate
if checkpoint and new_params:
self.optimizers_named_params_by_group[model_name].append(new_params)
self.optimizers[model_name].add_param_group(
{"params": list(new_params.values())}
)
@staticmethod @staticmethod
def set_model_learnable(model, learnable=True): def set_model_learnable(model, learnable=True):
for p in list(model.parameters()): for p in list(model.parameters()):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment