Skip to content
Snippets Groups Projects
Commit 378a86c0 authored by Mélodie Boillet's avatar Mélodie Boillet Committed by Yoann Schneider
Browse files

Remove growing models

parent 46720ba7
No related branches found
No related tags found
1 merge request!141Remove growing models
......@@ -357,7 +357,6 @@ class GenericTrainingManager:
Load the optimizer of each model
"""
for model_name in self.models.keys():
new_params = dict()
if (
checkpoint
and "optimizer_named_params_{}".format(model_name) in checkpoint
......@@ -365,16 +364,6 @@ class GenericTrainingManager:
self.optimizers_named_params_by_group[model_name] = checkpoint[
"optimizer_named_params_{}".format(model_name)
]
# for progressively growing models
for name, param in self.models[model_name].named_parameters():
existing = False
for gr in self.optimizers_named_params_by_group[model_name]:
if name in gr:
gr[name] = param
existing = True
break
if not existing:
new_params.update({name: param})
else:
self.optimizers_named_params_by_group[model_name] = [
dict(),
......@@ -420,13 +409,6 @@ class GenericTrainingManager:
checkpoint["lr_scheduler_{}_state_dict".format(model_name)]
)
# for progressively growing models, keeping learning rate
if checkpoint and new_params:
self.optimizers_named_params_by_group[model_name].append(new_params)
self.optimizers[model_name].add_param_group(
{"params": list(new_params.values())}
)
@staticmethod
def set_model_learnable(model, learnable=True):
for p in list(model.parameters()):
......
0% Loading — or retry.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment