Skip to content
Snippets Groups Projects
Commit 5fc3f984 authored by Manon Blanco's avatar Manon Blanco
Browse files

Generate the correct `parameters.yml` file directly during training

parent 2330f7ee
No related branches found
No related tags found
No related merge requests found
# -*- coding: utf-8 -*-
import json
import os
import random
from copy import deepcopy
from enum import Enum
from time import time
import numpy as np
......@@ -454,37 +453,57 @@ class GenericTrainingManager:
def save_params(self):
"""
Output text file containing a summary of all hyperparameters chosen for the training
Output yaml file containing a summary of all hyperparameters chosen for the training
"""
path = os.path.join(self.paths["results"], "parameters.yml")
if os.path.isfile(path):
return
def compute_nb_params(module):
return sum([np.prod(p.size()) for p in list(module.parameters())])
encoder_keys = ["input_channels", "dropout"]
decoder_keys = [
"enc_dim",
"l_max",
"h_max",
"w_max",
"dec_num_layers",
"dec_num_heads",
"dec_res_dropout",
"dec_pred_dropout",
"dec_att_dropout",
"dec_dim_feedforward",
"use_1d_pe",
"use_2d_pe",
"use_lstm",
"vocab_size",
"attention_win",
]
def class_to_str_dict(my_dict):
for key in my_dict.keys():
if callable(my_dict[key]):
my_dict[key] = my_dict[key].__name__
elif isinstance(my_dict[key], np.ndarray):
my_dict[key] = my_dict[key].tolist()
elif isinstance(my_dict[key], dict):
my_dict[key] = class_to_str_dict(my_dict[key])
return my_dict
params = {
"version": "0.0.1",
"parameters": {
"max_char_prediction": self.params["training_params"][
"max_char_prediction"
],
"encoder": {
key: self.params["model_params"][key] for key in encoder_keys
},
"decoder": {
key: self.params["model_params"][key] for key in decoder_keys
},
"preprocessings": [
{
key: value.value if isinstance(value, Enum) else value
for key, value in preprocessing.items()
}
for preprocessing in self.params["dataset_params"]["config"].get(
"preprocessings", []
)
],
},
}
path = os.path.join(self.paths["results"], "params")
if os.path.isfile(path):
return
params = class_to_str_dict(my_dict=deepcopy(self.params))
total_params = 0
for model_name in self.models.keys():
current_params = compute_nb_params(self.models[model_name])
params["model_params"]["models"][model_name] = [
params["model_params"]["models"][model_name],
"{:,}".format(current_params),
]
total_params += current_params
params["model_params"]["total_params"] = "{:,}".format(total_params)
with open(path, "w") as f:
json.dump(params, f, indent=4)
yaml.dump(params, f)
def backward_loss(self, loss, retain_graph=False):
self.scaler.scale(loss).backward(retain_graph=retain_graph)
......
......@@ -45,34 +45,4 @@ Once the training is complete, you can apply a trained DAN model on an image.
To do this, you will need to:
1. Create a `parameters.yml` file using the parameters saved during training in the `params` file, located in `{training_params.output_folder}/results`. This file should have the following format:
```yml
version: 0.0.1
parameters:
max_char_prediction: int
encoder:
input_channels: int
dropout: float
decoder:
enc_dim: int
l_max: int
dec_pred_dropout: float
attention_win: int
use_1d_pe: bool
use_lstm: bool
vocab_size: int
h_max: int
w_max: int
dec_num_layers: int
dec_dim_feedforward: int
dec_num_heads: int
dec_att_dropout: float
dec_res_dropout: float
preprocessings:
- type: str
max_height: int
max_width: int
fixed_height: int
fixed_width: int
```
2. Apply a trained DAN model on an image using the [predict command](../usage/predict.md).
1. Apply a trained DAN model on an image using the [predict command](../usage/predict.md).
......@@ -9,7 +9,7 @@ from tests.conftest import FIXTURES
@pytest.mark.parametrize(
"expected_best_model_name, expected_last_model_name, training_res, val_res, test_res",
"expected_best_model_name, expected_last_model_name, training_res, val_res, test_res, params_res",
(
(
"best_0.pt",
......@@ -41,6 +41,37 @@ from tests.conftest import FIXTURES
"wer_no_punct": 1.0,
"nb_samples": 2,
},
{
"version": "0.0.1",
"parameters": {
"max_char_prediction": 30,
"encoder": {"input_channels": 3, "dropout": 0.5},
"decoder": {
"enc_dim": 256,
"l_max": 15000,
"h_max": 500,
"w_max": 1000,
"dec_num_layers": 8,
"dec_num_heads": 4,
"dec_res_dropout": 0.1,
"dec_pred_dropout": 0.1,
"dec_att_dropout": 0.1,
"dec_dim_feedforward": 256,
"use_1d_pe": True,
"use_2d_pe": True,
"use_lstm": False,
"vocab_size": 96,
"attention_win": 100,
},
"preprocessings": [
{
"max_height": 2000,
"max_width": 2000,
"type": "max_resize",
}
],
},
},
),
),
)
......@@ -50,6 +81,7 @@ def test_train_and_test(
training_res,
val_res,
test_res,
params_res,
training_config,
tmp_path,
):
......@@ -146,3 +178,13 @@ def test_train_and_test(
if "time" not in metric
}
assert res == expected_res
# Check that the parameters file is correct
with (
tmp_path
/ training_config["training_params"]["output_folder"]
/ "results"
/ "parameters.yml"
).open() as f:
res = yaml.safe_load(f)
assert res == params_res
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment