Generate the correct `parameters.yml` file directly during training

Merged: Manon Blanco requested to merge `create-correct-parameters-file` into `main`.
 # -*- coding: utf-8 -*-
-import json
 import os
 import random
 from copy import deepcopy
+from enum import Enum
 from time import time
 import numpy as np
+import yaml
@@ -454,37 +453,51 @@ class GenericTrainingManager:
     def save_params(self):
         """
-        Output text file containing a summary of all hyperparameters chosen for the training
+        Output yaml file containing a summary of all hyperparameters chosen for the training
         """
-
-        def compute_nb_params(module):
-            return sum([np.prod(p.size()) for p in list(module.parameters())])
-
-        def class_to_str_dict(my_dict):
-            for key in my_dict.keys():
-                if callable(my_dict[key]):
-                    my_dict[key] = my_dict[key].__name__
-                elif isinstance(my_dict[key], np.ndarray):
-                    my_dict[key] = my_dict[key].tolist()
-                elif isinstance(my_dict[key], dict):
-                    my_dict[key] = class_to_str_dict(my_dict[key])
-            return my_dict
-
-        path = os.path.join(self.paths["results"], "params")
+        path = os.path.join(self.paths["results"], "parameters.yml")
         if os.path.isfile(path):
             return
-        params = class_to_str_dict(my_dict=deepcopy(self.params))
-        total_params = 0
-        for model_name in self.models.keys():
-            current_params = compute_nb_params(self.models[model_name])
-            params["model_params"]["models"][model_name] = [
-                params["model_params"]["models"][model_name],
-                "{:,}".format(current_params),
-            ]
-            total_params += current_params
-        params["model_params"]["total_params"] = "{:,}".format(total_params)
+        params = {
+            "parameters": {
+                "max_char_prediction": self.params["training_params"][
+                    "max_char_prediction"
+                ],
+                "encoder": {
+                    "dropout": self.params["model_params"]["dropout"],
+                },
+                "decoder": {
+                    key: self.params["model_params"][key]
+                    for key in [
+                        "enc_dim",
+                        "l_max",
+                        "h_max",
+                        "w_max",
+                        "dec_num_layers",
+                        "dec_num_heads",
+                        "dec_res_dropout",
+                        "dec_pred_dropout",
+                        "dec_att_dropout",
+                        "dec_dim_feedforward",
+                        "vocab_size",
+                        "attention_win",
+                    ]
+                },
+                "preprocessings": [
+                    {
+                        key: value.value if isinstance(value, Enum) else value
+                        for key, value in preprocessing.items()
+                    }
+                    for preprocessing in self.params["dataset_params"]["config"].get(
+                        "preprocessings", []
+                    )
+                ],
+            },
+        }
         with open(path, "w") as f:
-            json.dump(params, f, indent=4)
+            yaml.dump(params, f)
 
     def backward_loss(self, loss, retain_graph=False):
         self.scaler.scale(loss).backward(retain_graph=retain_graph)
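
For reference, the file written by the new save_params can be read back with yaml.safe_load. A minimal sketch of doing so at prediction time; the results/ path and the variable names are illustrative, only the key names come from the dict built above:

import yaml

# Load the parameters.yml produced by save_params.
with open("results/parameters.yml") as f:
    params = yaml.safe_load(f)["parameters"]

max_chars = params["max_char_prediction"]  # decoding length limit
decoder_config = params["decoder"]         # enc_dim, l_max, h_max, w_max, ...
preprocessings = params["preprocessings"]  # list of dicts; Enum members were
                                           # serialized as their .value

Note that yaml.dump sorts keys alphabetically by default, so the on-disk order may differ from the order of the dict in the code.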
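For context on the unchanged backward_loss method shown above: it wraps PyTorch's automatic mixed precision (AMP) gradient scaling. A minimal sketch of the kind of training step it belongs to; the model, optimizer, and data here are toy placeholders, not part of this merge request (requires a CUDA device):

import torch

model = torch.nn.Linear(8, 1).cuda()
optimizer = torch.optim.Adam(model.parameters())
scaler = torch.cuda.amp.GradScaler()

for _ in range(3):
    batch = torch.randn(4, 8, device="cuda")
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = model(batch).mean()
    scaler.scale(loss).backward()  # the call that backward_loss wraps
    scaler.step(optimizer)         # unscales gradients, then optimizer.step()
    scaler.update()                # adjusts the loss scale for the next step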