Skip to content
Snippets Groups Projects
Commit 5fc3f984 authored by Manon Blanco
Browse files

Generate the correct `parameters.yml` file directly during training

parent 2330f7ee
No related branches found
No related tags found
No related merge requests found
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import json
import os import os
import random import random
from copy import deepcopy from enum import Enum
from time import time from time import time
import numpy as np import numpy as np
...@@ -454,37 +453,57 @@ class GenericTrainingManager: ...@@ -454,37 +453,57 @@ class GenericTrainingManager:
def save_params(self): def save_params(self):
""" """
Output text file containing a summary of all hyperparameters chosen for the training Output yaml file containing a summary of all hyperparameters chosen for the training
""" """
path = os.path.join(self.paths["results"], "parameters.yml")
if os.path.isfile(path):
return
def compute_nb_params(module): encoder_keys = ["input_channels", "dropout"]
return sum([np.prod(p.size()) for p in list(module.parameters())]) decoder_keys = [
"enc_dim",
"l_max",
"h_max",
"w_max",
"dec_num_layers",
"dec_num_heads",
"dec_res_dropout",
"dec_pred_dropout",
"dec_att_dropout",
"dec_dim_feedforward",
"use_1d_pe",
"use_2d_pe",
"use_lstm",
"vocab_size",
"attention_win",
]
def class_to_str_dict(my_dict): params = {
for key in my_dict.keys(): "version": "0.0.1",
if callable(my_dict[key]): "parameters": {
my_dict[key] = my_dict[key].__name__ "max_char_prediction": self.params["training_params"][
elif isinstance(my_dict[key], np.ndarray): "max_char_prediction"
my_dict[key] = my_dict[key].tolist() ],
elif isinstance(my_dict[key], dict): "encoder": {
my_dict[key] = class_to_str_dict(my_dict[key]) key: self.params["model_params"][key] for key in encoder_keys
return my_dict },
"decoder": {
key: self.params["model_params"][key] for key in decoder_keys
},
"preprocessings": [
{
key: value.value if isinstance(value, Enum) else value
for key, value in preprocessing.items()
}
for preprocessing in self.params["dataset_params"]["config"].get(
"preprocessings", []
)
],
},
}
path = os.path.join(self.paths["results"], "params")
if os.path.isfile(path):
return
params = class_to_str_dict(my_dict=deepcopy(self.params))
total_params = 0
for model_name in self.models.keys():
current_params = compute_nb_params(self.models[model_name])
params["model_params"]["models"][model_name] = [
params["model_params"]["models"][model_name],
"{:,}".format(current_params),
]
total_params += current_params
params["model_params"]["total_params"] = "{:,}".format(total_params)
with open(path, "w") as f: with open(path, "w") as f:
json.dump(params, f, indent=4) yaml.dump(params, f)
def backward_loss(self, loss, retain_graph=False): def backward_loss(self, loss, retain_graph=False):
self.scaler.scale(loss).backward(retain_graph=retain_graph) self.scaler.scale(loss).backward(retain_graph=retain_graph)
......
...@@ -45,34 +45,4 @@ Once the training is complete, you can apply a trained DAN model on an image. ...@@ -45,34 +45,4 @@ Once the training is complete, you can apply a trained DAN model on an image.
To do this, you will need to: To do this, you will need to:
1. Create a `parameters.yml` file using the parameters saved during training in the `params` file, located in `{training_params.output_folder}/results`. This file should have the following format: 1. Apply a trained DAN model on an image using the [predict command](../usage/predict.md).
```yml
version: 0.0.1
parameters:
max_char_prediction: int
encoder:
input_channels: int
dropout: float
decoder:
enc_dim: int
l_max: int
dec_pred_dropout: float
attention_win: int
use_1d_pe: bool
use_lstm: bool
vocab_size: int
h_max: int
w_max: int
dec_num_layers: int
dec_dim_feedforward: int
dec_num_heads: int
dec_att_dropout: float
dec_res_dropout: float
preprocessings:
- type: str
max_height: int
max_width: int
fixed_height: int
fixed_width: int
```
2. Apply a trained DAN model on an image using the [predict command](../usage/predict.md).
...@@ -9,7 +9,7 @@ from tests.conftest import FIXTURES ...@@ -9,7 +9,7 @@ from tests.conftest import FIXTURES
@pytest.mark.parametrize( @pytest.mark.parametrize(
"expected_best_model_name, expected_last_model_name, training_res, val_res, test_res", "expected_best_model_name, expected_last_model_name, training_res, val_res, test_res, params_res",
( (
( (
"best_0.pt", "best_0.pt",
...@@ -41,6 +41,37 @@ from tests.conftest import FIXTURES ...@@ -41,6 +41,37 @@ from tests.conftest import FIXTURES
"wer_no_punct": 1.0, "wer_no_punct": 1.0,
"nb_samples": 2, "nb_samples": 2,
}, },
{
"version": "0.0.1",
"parameters": {
"max_char_prediction": 30,
"encoder": {"input_channels": 3, "dropout": 0.5},
"decoder": {
"enc_dim": 256,
"l_max": 15000,
"h_max": 500,
"w_max": 1000,
"dec_num_layers": 8,
"dec_num_heads": 4,
"dec_res_dropout": 0.1,
"dec_pred_dropout": 0.1,
"dec_att_dropout": 0.1,
"dec_dim_feedforward": 256,
"use_1d_pe": True,
"use_2d_pe": True,
"use_lstm": False,
"vocab_size": 96,
"attention_win": 100,
},
"preprocessings": [
{
"max_height": 2000,
"max_width": 2000,
"type": "max_resize",
}
],
},
},
), ),
), ),
) )
...@@ -50,6 +81,7 @@ def test_train_and_test( ...@@ -50,6 +81,7 @@ def test_train_and_test(
training_res, training_res,
val_res, val_res,
test_res, test_res,
params_res,
training_config, training_config,
tmp_path, tmp_path,
): ):
...@@ -146,3 +178,13 @@ def test_train_and_test( ...@@ -146,3 +178,13 @@ def test_train_and_test(
if "time" not in metric if "time" not in metric
} }
assert res == expected_res assert res == expected_res
# Check that the parameters file is correct
with (
tmp_path
/ training_config["training_params"]["output_folder"]
/ "results"
/ "parameters.yml"
).open() as f:
res = yaml.safe_load(f)
assert res == params_res
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment