diff --git a/dan/ocr/manager/ocr.py b/dan/ocr/manager/ocr.py
index 5e8c9c2889cb4322e0c2b514628d4415b139368e..31079a0449a191173d15dff580a3ac274afd8f4b 100644
--- a/dan/ocr/manager/ocr.py
+++ b/dan/ocr/manager/ocr.py
@@ -14,11 +14,15 @@ from dan.utils import pad_images, pad_sequences_1D
 
 
 class OCRDatasetManager:
-    def __init__(self, params, device: str):
-        self.params = params
+    def __init__(
+        self, dataset_params: dict, training_params: dict, device: torch.device
+    ):
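+        # dataset_params: the "dataset" section of the configuration
+        # training_params: the "training" section; its "data" and "device" sub-sections are used here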
+        self.params = dataset_params
+        self.training_params = training_params
+        self.device_params = training_params["device"]
 
         # Whether data should be copied on GPU via https://pytorch.org/docs/stable/generated/torch.Tensor.pin_memory.html
-        self.pin_memory = device != "cpu"
+        self.pin_memory = device != torch.device("cpu")
 
         self.train_dataset = None
         self.valid_datasets = dict()
@@ -32,37 +36,27 @@ class OCRDatasetManager:
         self.valid_samplers = dict()
         self.test_samplers = dict()
 
-        self.mean = (
-            np.array(params["config"]["mean"])
-            if "mean" in params["config"].keys()
-            else None
-        )
-        self.std = (
-            np.array(params["config"]["std"])
-            if "std" in params["config"].keys()
-            else None
-        )
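+        # Mean and std are no longer read from the dataset config; they are expected to be set later on (outside of __init__)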
+        self.mean = None
+        self.std = None
 
         self.generator = torch.Generator()
         self.generator.manual_seed(0)
 
-        self.load_in_memory = (
-            self.params["config"]["load_in_memory"]
-            if "load_in_memory" in self.params["config"]
-            else True
-        )
+        self.load_in_memory = self.training_params["data"].get("load_in_memory", True)
         self.charset = self.get_charset()
         self.tokens = self.get_tokens()
-        self.params["config"]["padding_token"] = self.tokens["pad"]
+        self.training_params["data"]["padding_token"] = self.tokens["pad"]
 
-        self.my_collate_function = OCRCollateFunction(self.params["config"])
+        self.my_collate_function = OCRCollateFunction(
+            padding_token=training_params["data"]["padding_token"]
+        )
         self.augmentation = (
             get_augmentation_transforms()
-            if self.params["config"]["augmentation"]
+            if self.training_params["data"]["augmentation"]
             else None
         )
         self.preprocessing = get_preprocessing_transforms(
-            params["config"]["preprocessings"], to_pil_image=True
+            training_params["data"]["preprocessings"], to_pil_image=True
         )
 
     def load_datasets(self):
@@ -100,18 +94,18 @@ class OCRDatasetManager:
         """
         Load training and validation data samplers
         """
-        if self.params["use_ddp"]:
+        if self.device_params["use_ddp"]:
             self.train_sampler = DistributedSampler(
                 self.train_dataset,
-                num_replicas=self.params["num_gpu"],
-                rank=self.params["ddp_rank"],
+                num_replicas=self.device_params["nb_gpu"],
+                rank=self.device_params["ddp_rank"],
                 shuffle=True,
             )
             for custom_name in self.valid_datasets.keys():
                 self.valid_samplers[custom_name] = DistributedSampler(
                     self.valid_datasets[custom_name],
-                    num_replicas=self.params["num_gpu"],
-                    rank=self.params["ddp_rank"],
+                    num_replicas=self.device_params["nb_gpu"],
+                    rank=self.device_params["ddp_rank"],
                     shuffle=False,
                 )
         else:
@@ -124,11 +118,12 @@ class OCRDatasetManager:
         """
         self.train_loader = DataLoader(
             self.train_dataset,
-            batch_size=self.params["batch_size"],
+            batch_size=self.training_params["data"]["batch_size"],
             shuffle=True if self.train_sampler is None else False,
             drop_last=False,
             sampler=self.train_sampler,
-            num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
+            num_workers=self.device_params["nb_gpu"]
+            * self.training_params["data"]["worker_per_gpu"],
             pin_memory=self.pin_memory,
             collate_fn=self.my_collate_function,
             worker_init_fn=self.seed_worker,
@@ -141,7 +136,8 @@ class OCRDatasetManager:
                 batch_size=1,
                 sampler=self.valid_samplers[key],
                 shuffle=False,
-                num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
+                num_workers=self.device_params["nb_gpu"]
+                * self.training_params["data"]["worker_per_gpu"],
                 pin_memory=self.pin_memory,
                 drop_last=False,
                 collate_fn=self.my_collate_function,
@@ -181,11 +177,11 @@ class OCRDatasetManager:
             std=self.std,
         )
 
-        if self.params["use_ddp"]:
+        if self.device_params["use_ddp"]:
             self.test_samplers[custom_name] = DistributedSampler(
                 self.test_datasets[custom_name],
-                num_replicas=self.params["num_gpu"],
-                rank=self.params["ddp_rank"],
+                num_replicas=self.device_params["nb_gpu"],
+                rank=self.device_params["ddp_rank"],
                 shuffle=False,
             )
         else:
@@ -196,7 +192,8 @@ class OCRDatasetManager:
             batch_size=1,
             sampler=self.test_samplers[custom_name],
             shuffle=False,
-            num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
+            num_workers=self.device_params["nb_gpu"]
+            * self.training_params["data"]["worker_per_gpu"],
             pin_memory=self.pin_memory,
             drop_last=False,
             collate_fn=self.my_collate_function,
@@ -245,9 +242,8 @@ class OCRCollateFunction:
     Merge samples data to mini-batch data for OCR task
     """
 
-    def __init__(self, config):
-        self.label_padding_value = config["padding_token"]
-        self.config = config
+    def __init__(self, padding_token):
+        self.label_padding_value = padding_token
 
     def __call__(self, batch_data):
         labels = [batch_data[i]["token_label"] for i in range(len(batch_data))]
diff --git a/dan/ocr/manager/training.py b/dan/ocr/manager/training.py
index abff933fdbe802fc109fc237b261e23ffdc59aa8..e14dab8abbf668efb550eb674d2580fe9221f2f7 100644
--- a/dan/ocr/manager/training.py
+++ b/dan/ocr/manager/training.py
@@ -26,6 +26,8 @@ from dan.utils import fix_ddp_layers_names, ind_to_token
 if MLFLOW_AVAILABLE:
     import mlflow
 
+MODEL_NAMES = ("encoder", "decoder")
+
 
 class GenericTrainingManager:
     def __init__(self, params):
@@ -34,7 +36,7 @@ class GenericTrainingManager:
         self.params = params
         self.models = {}
         self.dataset = None
-        self.dataset_name = list(self.params["dataset_params"]["datasets"].values())[0]
+        self.dataset_name = list(self.params["dataset"]["datasets"].values())[0]
         self.paths = None
         self.latest_step = 0
         self.latest_epoch = -1
@@ -48,24 +50,25 @@ class GenericTrainingManager:
         self.writer = None
         self.metric_manager = dict()
 
+        self.device_params = self.params["training"]["device"]
+        self.nb_gpu = (
+            self.device_params["nb_gpu"]
+            if self.device_params["use_ddp"]
+            else torch.cuda.device_count()
+        )
+        # Number of worker processes: set to the number of available GPUs when using DDP, 1 otherwise.
+        self.nb_workers = self.nb_gpu if self.device_params["use_ddp"] else 1
+        self.tokens = self.params["dataset"].get("tokens")
+
         self.init_hardware_config()
         self.init_paths()
         self.load_dataset()
-        self.params["model_params"]["use_amp"] = self.params["training_params"][
-            "use_amp"
-        ]
-        self.nb_gpu = (
-            self.params["training_params"]["nb_gpu"]
-            if self.params["training_params"]["use_ddp"]
-            else 1
-        )
-        self.tokens = self.params["dataset_params"].get("tokens")
 
     def init_paths(self):
         """
         Create output folders for results and checkpoints
         """
-        output_path = self.params["training_params"]["output_folder"]
+        output_path = self.params["training"]["output_folder"]
         os.makedirs(output_path, exist_ok=True)
         checkpoints_path = os.path.join(output_path, "checkpoints")
         os.makedirs(checkpoints_path, exist_ok=True)
@@ -82,22 +85,10 @@ class GenericTrainingManager:
         """
         Load datasets, data samplers and data loaders
         """
-        self.params["dataset_params"]["use_ddp"] = self.params["training_params"][
-            "use_ddp"
-        ]
-        self.params["dataset_params"]["batch_size"] = self.params["training_params"][
-            "batch_size"
-        ]
-        self.params["dataset_params"]["num_gpu"] = self.params["training_params"][
-            "nb_gpu"
-        ]
-        self.params["dataset_params"]["worker_per_gpu"] = (
-            4
-            if "worker_per_gpu" not in self.params["dataset_params"]
-            else self.params["dataset_params"]["worker_per_gpu"]
-        )
         self.dataset = OCRDatasetManager(
-            self.params["dataset_params"], device=self.device
+            dataset_params=self.params["dataset"],
+            training_params=self.params["training"],
+            device=self.device,
         )
         self.dataset.load_datasets()
         self.dataset.load_ddp_samplers()
@@ -105,55 +96,42 @@ class GenericTrainingManager:
 
     def init_hardware_config(self):
         # Debug mode
-        if self.params["training_params"]["force_cpu"]:
-            self.params["training_params"]["use_ddp"] = False
-            self.params["training_params"]["use_amp"] = False
+        if self.device_params["force_cpu"]:
+            self.device_params["use_ddp"] = False
+            self.device_params["use_amp"] = False
+
         # Manage Distributed Data Parallel & GPU usage
-        self.manual_seed = (
-            1111
-            if "manual_seed" not in self.params["training_params"].keys()
-            else self.params["training_params"]["manual_seed"]
-        )
+        self.manual_seed = self.params["training"].get("manual_seed", 1111)
         self.ddp_config = {
-            "master": self.params["training_params"]["use_ddp"]
-            and self.params["training_params"]["ddp_rank"] == 0,
-            "address": "localhost"
-            if "ddp_addr" not in self.params["training_params"].keys()
-            else self.params["training_params"]["ddp_addr"],
-            "port": "11111"
-            if "ddp_port" not in self.params["training_params"].keys()
-            else self.params["training_params"]["ddp_port"],
-            "backend": "nccl"
-            if "ddp_backend" not in self.params["training_params"].keys()
-            else self.params["training_params"]["ddp_backend"],
-            "rank": self.params["training_params"]["ddp_rank"],
+            "master": self.device_params["use_ddp"]
+            and self.device_params["ddp_rank"] == 0,
+            "address": self.device_params.get("ddp_addr", "localhost"),
+            "port": self.device_params.get("ddp_addr", "11111"),
+            "backend": self.device_params.get("ddp_backend", "nccl"),
+            "rank": self.device_params["ddp_rank"],
         }
-        self.is_master = (
-            self.ddp_config["master"] or not self.params["training_params"]["use_ddp"]
-        )
-        if self.params["training_params"]["force_cpu"]:
-            self.device = "cpu"
+        self.is_master = self.ddp_config["master"] or not self.device_params["use_ddp"]
+        if self.device_params["force_cpu"]:
+            self.device = torch.device("cpu")
         else:
-            if self.params["training_params"]["use_ddp"]:
+            if self.device_params["use_ddp"]:
                 self.device = torch.device(self.ddp_config["rank"])
-                self.params["dataset_params"]["ddp_rank"] = self.ddp_config["rank"]
+                self.device_params["ddp_rank"] = self.ddp_config["rank"]
                 self.launch_ddp()
             else:
                 self.device = torch.device(
                     "cuda:0" if torch.cuda.is_available() else "cpu"
                 )
-        if self.device == "cpu":
-            self.params["model_params"]["device"] = "cpu"
+        if self.device == torch.device("cpu"):
+            self.params["model"]["device"] = "cpu"
         else:
-            self.params["model_params"]["device"] = self.device.type
+            self.params["model"]["device"] = self.device.type
         # Print GPU info
         # global
-        if (
-            self.params["training_params"]["use_ddp"] and self.ddp_config["master"]
-        ) or not self.params["training_params"]["use_ddp"]:
+        if self.ddp_config["master"] or not self.device_params["use_ddp"]:
             print("##################")
-            print("Available GPUS: {}".format(self.params["training_params"]["nb_gpu"]))
-            for i in range(self.params["training_params"]["nb_gpu"]):
+            print("Available GPUS: {}".format(self.nb_gpu))
+            for i in range(self.nb_gpu):
                 print(
                     "Rank {}: {} {}".format(
                         i,
@@ -164,10 +142,10 @@ class GenericTrainingManager:
             print("##################")
         # local
         print("Local GPU:")
-        if self.device != "cpu":
+        if self.device != torch.device("cpu"):
             print(
                 "Rank {}: {} {}".format(
-                    self.params["training_params"]["ddp_rank"],
+                    self.device_params["ddp_rank"],
                     torch.cuda.get_device_name(),
                     torch.cuda.get_device_properties(self.device),
                 )
@@ -180,14 +158,20 @@ class GenericTrainingManager:
         """
         Load model weights from scratch or from checkpoints
         """
-        # Instantiate Model
-        for model_name in self.params["model_params"]["models"].keys():
-            self.models[model_name] = self.params["model_params"]["models"][model_name](
-                self.params["model_params"]
-            )
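+        # Parameters shared by the encoder and decoder constructors, merged below with each sub-model's own config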
+        common_params = {
+            "h_max": self.params["model"].get("h_max"),
+            "w_max": self.params["model"].get("w_max"),
+            "device": self.device,
+            "vocab_size": self.params["model"]["vocab_size"],
+        }
+        # Instantiate encoder, decoder
+        for model_name in MODEL_NAMES:
+            params = self.params["model"][model_name]
+            model_class = params.get("class")
+            self.models[model_name] = model_class({**params, **common_params})
             self.models[model_name].to(self.device)  # To GPU or CPU
             # make the model compatible with Distributed Data Parallel if used
-            if self.params["training_params"]["use_ddp"]:
+            if self.device_params["use_ddp"]:
                 self.models[model_name] = DDP(
                     self.models[model_name],
                     [self.ddp_config["rank"]],
@@ -197,7 +181,7 @@ class GenericTrainingManager:
         # Handle curriculum dropout
         self.dropout_scheduler = DropoutScheduler(self.models)
 
-        self.scaler = GradScaler(enabled=self.params["training_params"]["use_amp"])
+        self.scaler = GradScaler(enabled=self.device_params["use_amp"])
 
         # Check if checkpoint exists
         checkpoint = self.get_checkpoint()
@@ -215,9 +199,9 @@ class GenericTrainingManager:
         """
         Seek if checkpoint exist, return None otherwise
         """
-        if self.params["training_params"]["load_epoch"] in ("best", "last"):
+        if self.params["training"]["load_epoch"] in ("best", "last"):
             for filename in os.listdir(self.paths["checkpoints"]):
-                if self.params["training_params"]["load_epoch"] in filename:
+                if self.params["training"]["load_epoch"] in filename:
                     return torch.load(
                         os.path.join(self.paths["checkpoints"], filename),
                         map_location=self.device,
@@ -240,7 +224,7 @@ class GenericTrainingManager:
             # Transform to DDP/from DDP model
             checkpoint[f"{model_name}_state_dict"] = fix_ddp_layers_names(
                 checkpoint[f"{model_name}_state_dict"],
-                self.params["training_params"]["use_ddp"],
+                self.device_params["use_ddp"],
             )
 
             self.models[model_name].load_state_dict(
@@ -259,10 +243,10 @@ class GenericTrainingManager:
                 pass
 
         # Handle transfer learning instructions
-        if self.params["model_params"]["transfer_learning"]:
+        if self.params["training"]["transfer_learning"]:
             # Iterates over models
-            for model_name in self.params["model_params"]["transfer_learning"].keys():
-                state_dict_name, path, learnable, strict = self.params["model_params"][
+            for model_name in self.params["training"]["transfer_learning"]:
+                state_dict_name, path, learnable, strict = self.params["training"][
                     "transfer_learning"
                 ][model_name]
 
@@ -271,7 +255,7 @@ class GenericTrainingManager:
                 # Transform to DDP/from DDP model
                 checkpoint[f"{model_name}_state_dict"] = fix_ddp_layers_names(
                     checkpoint[f"{model_name}_state_dict"],
-                    self.params["training_params"]["use_ddp"],
+                    self.device_params["use_ddp"],
                 )
 
                 try:
@@ -293,8 +277,8 @@ class GenericTrainingManager:
                             # for pre-training of decision layer
                             if (
                                 "end_conv" in key
-                                and "transfered_charset" in self.params["model_params"]
-                                and self.params["model_params"]["transfered_charset"]
+                                and "transfered_charset" in self.params["model"]
+                                and self.params["model"]["transfered_charset"]
                             ):
                                 self.adapt_decision_layer_to_old_charset(
                                     model_name, key, checkpoint, state_dict_name
@@ -325,13 +309,13 @@ class GenericTrainingManager:
         weights = checkpoint["{}_state_dict".format(state_dict_name)][key]
         new_size = list(weights.size())
         new_size[0] = (
-            len(self.dataset.charset) + self.params["model_params"]["additional_tokens"]
+            len(self.dataset.charset) + self.params["model"]["additional_tokens"]
         )
         new_weights = torch.zeros(new_size, device=weights.device, dtype=weights.dtype)
         old_charset = (
             checkpoint["charset"]
             if "charset" in checkpoint
-            else self.params["model_params"]["old_charset"]
+            else self.params["model"]["old_charset"]
         )
         if "bias" not in key:
             kaiming_uniform_(new_weights, nonlinearity="relu")
@@ -374,21 +358,18 @@ class GenericTrainingManager:
             self.reset_optimizer(model_name)
 
             # Handle learning rate schedulers
-            if (
-                "lr_schedulers" in self.params["training_params"]
-                and self.params["training_params"]["lr_schedulers"]
-            ):
+            if self.params["training"].get("lr_schedulers"):
                 key = (
                     "all"
-                    if "all" in self.params["training_params"]["lr_schedulers"]
+                    if "all" in self.params["training"]["lr_schedulers"]
                     else model_name
                 )
-                if key in self.params["training_params"]["lr_schedulers"]:
-                    self.lr_schedulers[model_name] = self.params["training_params"][
+                if key in self.params["training"]["lr_schedulers"]:
+                    self.lr_schedulers[model_name] = self.params["training"][
                         "lr_schedulers"
                     ][key]["class"](
                         self.optimizers[model_name],
-                        **self.params["training_params"]["lr_schedulers"][key]["args"],
+                        **self.params["training"]["lr_schedulers"][key]["args"],
                     )
 
             # Load optimizer state from past training
@@ -398,8 +379,8 @@ class GenericTrainingManager:
                 )
                 # Load optimizer scheduler config from past training if used
                 if (
-                    "lr_schedulers" in self.params["training_params"]
-                    and self.params["training_params"]["lr_schedulers"]
+                    "lr_schedulers" in self.params["training"]
+                    and self.params["training"]["lr_schedulers"]
                     and "lr_scheduler_{}_state_dict".format(model_name)
                     in checkpoint.keys()
                 ):
@@ -455,14 +436,10 @@ class GenericTrainingManager:
         Reset optimizer learning rate for given model
         """
         params = list(self.optimizers_named_params_by_group[model_name][0].values())
-        key = (
-            "all"
-            if "all" in self.params["training_params"]["optimizers"]
-            else model_name
-        )
-        self.optimizers[model_name] = self.params["training_params"]["optimizers"][key][
+        key = "all" if "all" in self.params["training"]["optimizers"] else model_name
+        self.optimizers[model_name] = self.params["training"]["optimizers"][key][
             "class"
-        ](params, **self.params["training_params"]["optimizers"][key]["args"])
+        ](params, **self.params["training"]["optimizers"][key]["args"])
         for i in range(1, len(self.optimizers_named_params_by_group[model_name])):
             self.optimizers[model_name].add_param_group(
                 {
@@ -478,7 +455,7 @@ class GenericTrainingManager:
         and a yaml file containing parameters used for inference
         """
 
-        def compute_nb_params(module):
+        def compute_nb_params(module) -> np.int64:
             return sum([np.prod(p.size()) for p in list(module.parameters())])
 
         def class_to_str_dict(my_dict):
@@ -509,14 +486,11 @@ class GenericTrainingManager:
             return
         params = class_to_str_dict(my_dict=deepcopy(self.params))
         total_params = 0
-        for model_name in self.models.keys():
-            current_params = compute_nb_params(self.models[model_name])
-            params["model_params"]["models"][model_name] = [
-                params["model_params"]["models"][model_name],
-                "{:,}".format(current_params),
-            ]
+        for model_name in MODEL_NAMES:
+            current_params = int(compute_nb_params(self.models[model_name]))
+            params["model"][model_name]["nb_params"] = current_params
             total_params += current_params
-        params["model_params"]["total_params"] = "{:,}".format(total_params)
+        params["model"]["total_params"] = "{:,}".format(total_params)
         params["mean"] = self.dataset.mean.tolist()
         params["std"] = self.dataset.std.tolist()
         with open(path, "w") as f:
@@ -526,32 +500,35 @@ class GenericTrainingManager:
         path = os.path.join(self.paths["results"], "inference_parameters.yml")
         if os.path.isfile(path):
             return
+
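+        # Decoder hyper-parameters copied verbatim from the model config into the inference parameters file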
+        decoder_params = {
+            key: params["model"]["decoder"][key]
+            for key in (
+                "l_max",
+                "dec_num_layers",
+                "dec_num_heads",
+                "dec_res_dropout",
+                "dec_pred_dropout",
+                "dec_att_dropout",
+                "dec_dim_feedforward",
+                "attention_win",
+                "enc_dim",
+            )
+        }
+
         inference_params = {
             "parameters": {
                 "mean": params["mean"],
                 "std": params["std"],
-                "max_char_prediction": params["training_params"]["max_char_prediction"],
-                "encoder": {
-                    "dropout": params["model_params"]["dropout"],
-                },
+                "max_char_prediction": params["dataset"]["max_char_prediction"],
+                "encoder": {"dropout": params["model"]["encoder"]["dropout"]},
                 "decoder": {
-                    key: params["model_params"][key]
-                    for key in [
-                        "enc_dim",
-                        "l_max",
-                        "h_max",
-                        "w_max",
-                        "dec_num_layers",
-                        "dec_num_heads",
-                        "dec_res_dropout",
-                        "dec_pred_dropout",
-                        "dec_att_dropout",
-                        "dec_dim_feedforward",
-                        "vocab_size",
-                        "attention_win",
-                    ]
+                    "h_max": params["model"]["h_max"],
+                    "w_max": params["model"]["w_max"],
+                    "vocab_size": params["model"]["vocab_size"],
+                    **decoder_params,
                 },
-                "preprocessings": params["dataset_params"]["config"]["preprocessings"],
+                "preprocessings": params["training"]["data"]["preprocessings"],
             },
         }
         with open(path, "w") as f:
@@ -565,14 +542,13 @@ class GenericTrainingManager:
             if names and model_name not in names:
                 continue
             if (
-                "gradient_clipping" in self.params["training_params"]
-                and model_name
-                in self.params["training_params"]["gradient_clipping"]["models"]
+                self.params["training"].get("gradient_clipping")
+                and model_name in self.params["training"]["gradient_clipping"]["models"]
             ):
                 self.scaler.unscale_(self.optimizers[model_name])
                 torch.nn.utils.clip_grad_norm_(
                     self.models[model_name].parameters(),
-                    self.params["training_params"]["gradient_clipping"]["max"],
+                    self.params["training"]["gradient_clipping"]["max"],
                 )
             self.scaler.step(self.optimizers[model_name])
         self.scaler.update()
@@ -592,8 +568,8 @@ class GenericTrainingManager:
             self.writer = SummaryWriter(self.paths["results"])
             self.save_params()
         # init variables
-        nb_epochs = self.params["training_params"]["max_nb_epochs"]
-        metric_names = self.params["training_params"]["train_metrics"]
+        nb_epochs = self.params["training"]["max_nb_epochs"]
+        metric_names = self.params["training"]["metrics"]["train"]
 
         display_values = None
         # perform epochs
@@ -623,21 +599,20 @@ class GenericTrainingManager:
                     )
                     batch_metrics["names"] = batch_data["names"]
                     # Merge metrics if Distributed Data Parallel is used
-                    if self.params["training_params"]["use_ddp"]:
+                    if self.device_params["use_ddp"]:
                         batch_metrics = self.merge_ddp_metrics(batch_metrics)
                     # Update learning rate via scheduler if one is used
-                    if self.params["training_params"]["lr_schedulers"]:
+                    if self.params["training"]["lr_schedulers"]:
                         for model_name in self.models:
                             key = (
                                 "all"
-                                if "all"
-                                in self.params["training_params"]["lr_schedulers"]
+                                if "all" in self.params["training"]["lr_schedulers"]
                                 else model_name
                             )
                             if (
                                 model_name in self.lr_schedulers
                                 and ind_batch
-                                % self.params["training_params"]["lr_schedulers"][key][
+                                % self.params["training"]["lr_schedulers"][key][
                                     "step_interval"
                                 ]
                                 == 0
@@ -658,7 +633,7 @@ class GenericTrainingManager:
                     self.metric_manager["train"].update_metrics(batch_metrics)
                     display_values = self.metric_manager["train"].get_display_values()
                     pbar.set_postfix(values=str(display_values))
-                    pbar.update(len(batch_data["names"]) * self.nb_gpu)
+                    pbar.update(len(batch_data["names"]) * self.nb_workers)
 
                 # Log MLflow metrics
                 logging_metrics(
@@ -669,17 +644,16 @@ class GenericTrainingManager:
                 # log metrics in tensorboard file
                 for key in display_values.keys():
                     self.writer.add_scalar(
-                        "{}_{}".format(
-                            self.params["dataset_params"]["train"]["name"], key
-                        ),
+                        "{}_{}".format(self.params["dataset"]["train"]["name"], key),
                         display_values[key],
                         num_epoch,
                     )
 
             # evaluate and compute metrics for valid sets
             if (
-                self.params["training_params"]["eval_on_valid"]
-                and num_epoch % self.params["training_params"]["eval_on_valid_interval"]
+                self.params["training"]["validation"]["eval_on_valid"]
+                and num_epoch
+                % self.params["training"]["validation"]["eval_on_valid_interval"]
                 == 0
             ):
                 for valid_set_name in self.dataset.valid_loaders.keys():
@@ -695,7 +669,7 @@ class GenericTrainingManager:
                                 eval_values[key],
                                 num_epoch,
                             )
-                        if valid_set_name == self.params["training_params"][
+                        if valid_set_name == self.params["training"]["validation"][
                             "set_name_focus_metric"
                         ] and (self.best is None or eval_values["cer"] <= self.best):
                             self.save_model(epoch=num_epoch, name="best")
@@ -706,8 +680,8 @@ class GenericTrainingManager:
                 self.check_and_update_curriculum()
 
             if (
-                "curriculum_model" in self.params["model_params"]
-                and self.params["model_params"]["curriculum_model"]
+                "curriculum_model" in self.params["model"]
+                and self.params["model"]["curriculum_model"]
             ):
                 self.update_curriculum_model()
 
@@ -724,7 +698,7 @@ class GenericTrainingManager:
         # Set models in eval mode
         for model_name in self.models.keys():
             self.models[model_name].eval()
-        metric_names = self.params["training_params"]["eval_metrics"]
+        metric_names = self.params["training"]["metrics"]["eval"]
         display_values = None
 
         # initialize epoch metrics
@@ -745,7 +719,7 @@ class GenericTrainingManager:
                     )
                     batch_metrics["names"] = batch_data["names"]
                     # merge metrics values if Distributed Data Parallel is used
-                    if self.params["training_params"]["use_ddp"]:
+                    if self.device_params["use_ddp"]:
                         batch_metrics = self.merge_ddp_metrics(batch_metrics)
 
                     # add batch metrics to epoch metrics
@@ -753,7 +727,7 @@ class GenericTrainingManager:
                     display_values = self.metric_manager[set_name].get_display_values()
 
                     pbar.set_postfix(values=str(display_values))
-                    pbar.update(len(batch_data["names"]) * self.nb_gpu)
+                    pbar.update(len(batch_data["names"]) * self.nb_workers)
 
                 # log metrics in MLflow
                 logging_metrics(
@@ -797,7 +771,7 @@ class GenericTrainingManager:
                     )
                     batch_metrics["names"] = batch_data["names"]
                     # merge batch metrics if Distributed Data Parallel is used
-                    if self.params["training_params"]["use_ddp"]:
+                    if self.device_params["use_ddp"]:
                         batch_metrics = self.merge_ddp_metrics(batch_metrics)
 
                     # add batch metrics to epoch metrics
@@ -807,7 +781,7 @@ class GenericTrainingManager:
                     ].get_display_values()
 
                     pbar.set_postfix(values=str(display_values))
-                    pbar.update(len(batch_data["names"]) * self.nb_gpu)
+                    pbar.update(len(batch_data["names"]) * self.nb_workers)
 
                 # log metrics in MLflow
                 logging_name = custom_name.split("-")[1]
@@ -849,7 +823,7 @@ class GenericTrainingManager:
         dist.init_process_group(
             self.ddp_config["backend"],
             rank=self.ddp_config["rank"],
-            world_size=self.params["training_params"]["nb_gpu"],
+            world_size=self.nb_gpu,
         )
         torch.cuda.set_device(self.ddp_config["rank"])
         random.seed(self.manual_seed)
@@ -933,7 +907,7 @@ class GenericTrainingManager:
 class OCRManager(GenericTrainingManager):
     def __init__(self, params):
         super(OCRManager, self).__init__(params)
-        self.params["model_params"]["vocab_size"] = len(self.dataset.charset)
+        self.params["model"]["vocab_size"] = len(self.dataset.charset)
 
 
 class Manager(OCRManager):
@@ -971,34 +945,24 @@ class Manager(OCRManager):
         reduced_size = [s[:2] for s in batch_data["imgs_reduced_shape"]]
         y_len = batch_data["labels_len"]
 
-        if "label_noise_scheduler" in self.params["training_params"]:
+        if "label_noise_scheduler" in self.params["training"]:
             error_rate = (
-                self.params["training_params"]["label_noise_scheduler"][
-                    "min_error_rate"
-                ]
+                self.params["training"]["label_noise_scheduler"]["min_error_rate"]
                 + min(
                     self.latest_step,
-                    self.params["training_params"]["label_noise_scheduler"][
-                        "total_num_steps"
-                    ],
+                    self.params["training"]["label_noise_scheduler"]["total_num_steps"],
                 )
                 * (
-                    self.params["training_params"]["label_noise_scheduler"][
-                        "max_error_rate"
-                    ]
-                    - self.params["training_params"]["label_noise_scheduler"][
-                        "min_error_rate"
-                    ]
+                    self.params["training"]["label_noise_scheduler"]["max_error_rate"]
+                    - self.params["training"]["label_noise_scheduler"]["min_error_rate"]
                 )
-                / self.params["training_params"]["label_noise_scheduler"][
-                    "total_num_steps"
-                ]
+                / self.params["training"]["label_noise_scheduler"]["total_num_steps"]
             )
             simulated_y_pred, y_len = self.add_label_noise(y, y_len, error_rate)
         else:
             simulated_y_pred = y
 
-        with autocast(enabled=self.params["training_params"]["use_amp"]):
+        with autocast(enabled=self.device_params["use_amp"]):
             hidden_predict = None
             cache = None
 
@@ -1006,7 +970,7 @@ class Manager(OCRManager):
             features_size = raw_features.size()
             b, c, h, w = features_size
 
-            if self.params["training_params"]["use_ddp"]:
+            if self.device_params["use_ddp"]:
                 pos_features = self.models[
                     "decoder"
                 ].module.features_updater.get_pos_features(raw_features)
@@ -1061,10 +1025,10 @@ class Manager(OCRManager):
         x = batch_data["imgs"].to(self.device)
         reduced_size = [s[:2] for s in batch_data["imgs_reduced_shape"]]
 
-        max_chars = self.params["training_params"]["max_char_prediction"]
+        max_chars = self.params["dataset"]["max_char_prediction"]
 
         start_time = time()
-        with autocast(enabled=self.params["training_params"]["use_amp"]):
+        with autocast(enabled=self.device_params["use_amp"]):
             b = x.size(0)
             reached_end = torch.zeros((b,), dtype=torch.bool, device=self.device)
             prediction_len = torch.zeros((b,), dtype=torch.int, device=self.device)
@@ -1106,7 +1070,7 @@ class Manager(OCRManager):
             else:
                 features = self.models["encoder"](x)
             features_size = features.size()
-            if self.params["training_params"]["use_ddp"]:
+            if self.device_params["use_ddp"]:
                 pos_features = self.models[
                     "decoder"
                 ].module.features_updater.get_pos_features(features)
diff --git a/dan/ocr/train.py b/dan/ocr/train.py
index 7c44e527450a85d2434cd832b0f81932991be48e..1469f19b3bcca8323b35ae494751537b9453b725 100644
--- a/dan/ocr/train.py
+++ b/dan/ocr/train.py
@@ -34,7 +34,7 @@ def train_and_test(rank, params, mlflow_logging=False):
     torch.backends.cudnn.benchmark = False
     torch.backends.cudnn.deterministic = True
 
-    params["training_params"]["ddp_rank"] = rank
+    params["training"]["device"]["ddp_rank"] = rank
     model = Manager(params)
     model.load_model()
 
@@ -44,11 +44,11 @@ def train_and_test(rank, params, mlflow_logging=False):
     model.train(mlflow_logging=mlflow_logging)
 
     # load weights giving best CER on valid set
-    model.params["training_params"]["load_epoch"] = "best"
+    model.params["training"]["load_epoch"] = "best"
     model.load_model()
 
     metrics = ["cer", "wer", "wer_no_punct", "time"]
-    for dataset_name in params["dataset_params"]["datasets"].keys():
+    for dataset_name in params["dataset"]["datasets"].keys():
         for set_name in ["test", "val", "train"]:
             model.predict(
                 "{}-{}".format(dataset_name, set_name),
@@ -79,7 +79,7 @@ def get_config():
             "aws_access_key_id": "",
             "aws_secret_access_key": "",
         },
-        "dataset_params": {
+        "dataset": {
             "datasets": {
                 dataset_name: "{}/{}_{}{}".format(
                     dataset_path, dataset_name, dataset_level, dataset_variant
@@ -101,7 +101,35 @@ def get_config():
                     (dataset_name, "test"),
                 ],
             },
-            "config": {
+            "max_char_prediction": 1000,  # max number of token prediction
+            "tokens": None,
+        },
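+        # Model architecture: shared 2D positional embedding limits plus encoder/decoder sub-configurations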
+        "model": {
+            "transfered_charset": True,  # Transfer learning of the decision layer based on charset of the line HTR model
+            "additional_tokens": 1,  # for decision layer = [<eot>, ], only for transferred charset
+            "encoder": {
+                "class": FCN_Encoder,
+                "dropout": 0.5,  # dropout rate for encoder
+                "nb_layers": 5,  # encoder
+            },
+            "h_max": 500,  # maximum height for encoder output (for 2D positional embedding)
+            "w_max": 1000,  # maximum width for encoder output (for 2D positional embedding)
+            "decoder": {
+                "class": GlobalHTADecoder,
+                "l_max": 15000,  # max predicted sequence (for 1D positional embedding)
+                "dec_num_layers": 8,  # number of transformer decoder layers
+                "dec_num_heads": 4,  # number of heads in transformer decoder layers
+                "dec_res_dropout": 0.1,  # dropout in transformer decoder layers
+                "dec_pred_dropout": 0.1,  # dropout rate before decision layer
+                "dec_att_dropout": 0.1,  # dropout rate in multi head attention
+                "dec_dim_feedforward": 256,  # number of dimension for feedforward layer in transformer decoder layers
+                "attention_win": 100,  # length of attention window
+                "enc_dim": 256,  # dimension of extracted features
+            },
+        },
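+        # Training settings: data loading, device/DDP, metrics, validation, optimizers and transfer learning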
+        "training": {
+            "data": {
+                "batch_size": 2,  # mini-batch size for training
                 "load_in_memory": True,  # Load all images in CPU memory
                 "worker_per_gpu": 4,  # Num of parallel processes per gpu for data loading
                 "preprocessings": [
@@ -113,54 +141,36 @@ def get_config():
                 ],
                 "augmentation": True,
             },
-            "tokens": None,
-        },
-        "model_params": {
-            "models": {
-                "encoder": FCN_Encoder,
-                "decoder": GlobalHTADecoder,
+            "device": {
+                "use_ddp": False,  # Use DistributedDataParallel
+                "ddp_port": "20027",
+                "use_amp": True,  # Enable automatic mix-precision
+                "nb_gpu": torch.cuda.device_count(),
+                "force_cpu": False,  # True for debug purposes
             },
-            # "transfer_learning": None,
-            "transfer_learning": {
-                # model_name: [state_dict_name, checkpoint_path, learnable, strict]
-                "encoder": [
-                    "encoder",
-                    "pretrained_models/dan_rimes_page.pt",
-                    True,
-                    True,
-                ],
-                "decoder": [
-                    "decoder",
-                    "pretrained_models/dan_rimes_page.pt",
-                    True,
-                    False,
-                ],
+            "metrics": {
+                "train": [
+                    "loss_ce",
+                    "cer",
+                    "wer",
+                    "wer_no_punct",
+                ],  # Metrics name for training
+                "eval": [
+                    "cer",
+                    "wer",
+                    "wer_no_punct",
+                ],  # Metrics name for evaluation on validation set during training
+            },
+            "validation": {
+                "eval_on_valid": True,  # Whether to eval and logs metrics on validation set during training or not
+                "eval_on_valid_interval": 5,  # Interval (in epochs) to evaluate during training
+                "set_name_focus_metric": "{}-val".format(
+                    dataset_name
+                ),  # Which dataset to focus on to select best weights
             },
-            "transfered_charset": True,  # Transfer learning of the decision layer based on charset of the line HTR model
-            "additional_tokens": 1,  # for decision layer = [<eot>, ], only for transferred charset
-            "dropout": 0.5,  # dropout rate for encoder
-            "enc_dim": 256,  # dimension of extracted features
-            "nb_layers": 5,  # encoder
-            "h_max": 500,  # maximum height for encoder output (for 2D positional embedding)
-            "w_max": 1000,  # maximum width for encoder output (for 2D positional embedding)
-            "l_max": 15000,  # max predicted sequence (for 1D positional embedding)
-            "dec_num_layers": 8,  # number of transformer decoder layers
-            "dec_num_heads": 4,  # number of heads in transformer decoder layers
-            "dec_res_dropout": 0.1,  # dropout in transformer decoder layers
-            "dec_pred_dropout": 0.1,  # dropout rate before decision layer
-            "dec_att_dropout": 0.1,  # dropout rate in multi head attention
-            "dec_dim_feedforward": 256,  # number of dimension for feedforward layer in transformer decoder layers
-            "attention_win": 100,  # length of attention window
-        },
-        "training_params": {
             "output_folder": "outputs/dan_esposalles_record",  # folder name for checkpoint and results
             "max_nb_epochs": 800,  # maximum number of epochs before to stop
             "load_epoch": "last",  # ["best", "last"]: last to continue training, best to evaluate
-            "batch_size": 2,  # mini-batch size for training
-            "use_ddp": False,  # Use DistributedDataParallel
-            "ddp_port": "20027",
-            "use_amp": True,  # Enable automatic mix-precision
-            "nb_gpu": torch.cuda.device_count(),
             "optimizers": {
                 "all": {
                     "class": Adam,
@@ -171,30 +181,28 @@ def get_config():
                 },
             },
             "lr_schedulers": None,  # Learning rate schedulers
-            "eval_on_valid": True,  # Whether to eval and logs metrics on validation set during training or not
-            "eval_on_valid_interval": 5,  # Interval (in epochs) to evaluate during training
-            "set_name_focus_metric": "{}-val".format(
-                dataset_name
-            ),  # Which dataset to focus on to select best weights
-            "train_metrics": [
-                "loss_ce",
-                "cer",
-                "wer",
-                "wer_no_punct",
-            ],  # Metrics name for training
-            "eval_metrics": [
-                "cer",
-                "wer",
-                "wer_no_punct",
-            ],  # Metrics name for evaluation on validation set during training
-            "force_cpu": False,  # True for debug purposes
-            "max_char_prediction": 1000,  # max number of token prediction
             # Keep teacher forcing rate to 20% during whole training
             "label_noise_scheduler": {
                 "min_error_rate": 0.2,
                 "max_error_rate": 0.2,
                 "total_num_steps": 5e4,
             },
+            # "transfer_learning": None,
+            "transfer_learning": {
+                # model_name: [state_dict_name, checkpoint_path, learnable, strict]
+                "encoder": [
+                    "encoder",
+                    "pretrained_models/dan_rimes_page.pt",
+                    True,
+                    True,
+                ],
+                "decoder": [
+                    "decoder",
+                    "pretrained_models/dan_rimes_page.pt",
+                    True,
+                    False,
+                ],
+            },
         },
     }
 
@@ -218,22 +226,22 @@ def serialize_config(config):
     serialized_config["mlflow"]["aws_secret_access_key"] = ""
 
     # Get the name of the class
-    serialized_config["model_params"]["models"]["encoder"] = serialized_config[
-        "model_params"
-    ]["models"]["encoder"].__name__
-    serialized_config["model_params"]["models"]["decoder"] = serialized_config[
-        "model_params"
-    ]["models"]["decoder"].__name__
-    serialized_config["training_params"]["optimizers"]["all"][
-        "class"
-    ] = serialized_config["training_params"]["optimizers"]["all"]["class"].__name__
+    serialized_config["model"]["models"]["encoder"] = serialized_config["model"][
+        "models"
+    ]["encoder"].__name__
+    serialized_config["model"]["models"]["decoder"] = serialized_config["model"][
+        "models"
+    ]["decoder"].__name__
+    serialized_config["training"]["optimizers"]["all"]["class"] = serialized_config[
+        "training"
+    ]["optimizers"]["all"]["class"].__name__
 
     # Cast the functions to str
-    serialized_config["dataset_params"]["config"]["augmentation"] = str(
-        serialized_config["dataset_params"]["config"]["augmentation"]
+    serialized_config["dataset"]["config"]["augmentation"] = str(
+        serialized_config["dataset"]["config"]["augmentation"]
     )
-    serialized_config["training_params"]["nb_gpu"] = str(
-        serialized_config["training_params"]["nb_gpu"]
+    serialized_config["training"]["nb_gpu"] = str(
+        serialized_config["training"]["nb_gpu"]
     )
 
     return serialized_config
@@ -241,13 +249,13 @@ def serialize_config(config):
 
 def start_training(config, mlflow_logging: bool) -> None:
     if (
-        config["training_params"]["use_ddp"]
-        and not config["training_params"]["force_cpu"]
+        config["training"]["device"]["use_ddp"]
+        and not config["training"]["device"]["force_cpu"]
     ):
         mp.spawn(
             train_and_test,
             args=(config, mlflow_logging),
-            nprocs=config["training_params"]["nb_gpu"],
+            nprocs=config["training"]["device"]["nb_gpu"],
         )
     else:
         train_and_test(0, config, mlflow_logging)
@@ -269,9 +277,7 @@ def run():
     if "mlflow" not in config:
         start_training(config, mlflow_logging=False)
     else:
-        labels_path = (
-            Path(config["dataset_params"]["datasets"][dataset_name]) / "labels.json"
-        )
+        labels_path = Path(config["dataset"]["datasets"][dataset_name]) / "labels.json"
         with start_mlflow_run(config["mlflow"]) as (run, created):
             if created:
                 logger.info(f"Started MLflow run with ID ({run.info.run_id})")
diff --git a/tests/conftest.py b/tests/conftest.py
index 3ec0ea4eefd4d180b8f920b8fc939ed4275136b8..3abc8c71d7bb355bf3896b88c77d68fd06ca073c 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -28,7 +28,7 @@ def demo_db(database_path):
 @pytest.fixture
 def training_config():
     return {
-        "dataset_params": {
+        "dataset": {
             "datasets": {
                 "training": "./tests/data/training/training_dataset",
             },
@@ -48,48 +48,75 @@ def training_config():
                     ("training", "test"),
                 ],
             },
-            "config": {
+            "max_char_prediction": 30,  # max number of token prediction
+            "tokens": None,
+        },
+        "model": {
+            "transfered_charset": True,  # Transfer learning of the decision layer based on charset of the line HTR model
+            "additional_tokens": 1,  # for decision layer = [<eot>, ], only for transferred charset
+            "encoder": {
+                "class": FCN_Encoder,
+                "dropout": 0.5,  # dropout rate for encoder
+                "nb_layers": 5,  # encoder
+            },
+            "h_max": 500,  # maximum height for encoder output (for 2D positional embedding)
+            "w_max": 1000,  # maximum width for encoder output (for 2D positional embedding)
+            "decoder": {
+                "class": GlobalHTADecoder,
+                "l_max": 15000,  # max predicted sequence (for 1D positional embedding)
+                "dec_num_layers": 8,  # number of transformer decoder layers
+                "dec_num_heads": 4,  # number of heads in transformer decoder layers
+                "dec_res_dropout": 0.1,  # dropout in transformer decoder layers
+                "dec_pred_dropout": 0.1,  # dropout rate before decision layer
+                "dec_att_dropout": 0.1,  # dropout rate in multi head attention
+                "dec_dim_feedforward": 256,  # number of dimension for feedforward layer in transformer decoder layers
+                "attention_win": 100,  # length of attention window
+                "enc_dim": 256,  # dimension of extracted features
+            },
+        },
+        "training": {
+            "data": {
+                "batch_size": 2,  # mini-batch size for training
                 "load_in_memory": True,  # Load all images in CPU memory
+                "worker_per_gpu": 4,  # Num of parallel processes per gpu for data loading
                 "preprocessings": [
                     {
                         "type": Preprocessing.MaxResize,
                         "max_width": 2000,
                         "max_height": 2000,
-                    },
+                    }
                 ],
                 "augmentation": True,
             },
-            "tokens": None,
-        },
-        "model_params": {
-            "models": {
-                "encoder": FCN_Encoder,
-                "decoder": GlobalHTADecoder,
+            "device": {
+                "use_ddp": False,  # Use DistributedDataParallel
+                "ddp_port": "20027",
+                "use_amp": True,  # Enable automatic mix-precision
+                "nb_gpu": 0,
+                "force_cpu": True,  # True for debug purposes
+            },
+            "metrics": {
+                "train": [
+                    "loss_ce",
+                    "cer",
+                    "wer",
+                    "wer_no_punct",
+                ],  # Metrics name for training
+                "eval": [
+                    "cer",
+                    "wer",
+                    "wer_no_punct",
+                ],  # Metrics name for evaluation on validation set during training
+            },
+            "validation": {
+                "eval_on_valid": True,  # Whether to eval and logs metrics on validation set during training or not
+                "eval_on_valid_interval": 2,  # Interval (in epochs) to evaluate during training
+                "set_name_focus_metric": "training-val",
             },
-            "transfer_learning": None,
-            "transfered_charset": True,  # Transfer learning of the decision layer based on charset of the line HTR model
-            "additional_tokens": 1,  # for decision layer = [<eot>, ], only for transferred charset
-            "dropout": 0.5,  # dropout rate for encoder
-            "enc_dim": 256,  # dimension of extracted features
-            "nb_layers": 5,  # encoder
-            "h_max": 500,  # maximum height for encoder output (for 2D positional embedding)
-            "w_max": 1000,  # maximum width for encoder output (for 2D positional embedding)
-            "l_max": 15000,  # max predicted sequence (for 1D positional embedding)
-            "dec_num_layers": 8,  # number of transformer decoder layers
-            "dec_num_heads": 4,  # number of heads in transformer decoder layers
-            "dec_res_dropout": 0.1,  # dropout in transformer decoder layers
-            "dec_pred_dropout": 0.1,  # dropout rate before decision layer
-            "dec_att_dropout": 0.1,  # dropout rate in multi head attention
-            "dec_dim_feedforward": 256,  # number of dimension for feedforward layer in transformer decoder layers
-            "attention_win": 100,  # length of attention window
-        },
-        "training_params": {
             "output_folder": "dan_trained_model",  # folder name for checkpoint and results
+            "gradient_clipping": {},
             "max_nb_epochs": 4,  # maximum number of epochs before to stop
             "load_epoch": "last",  # ["best", "last"]: last to continue training, best to evaluate
-            "batch_size": 2,  # mini-batch size for training
-            "use_ddp": False,  # Use DistributedDataParallel
-            "nb_gpu": 0,
             "optimizers": {
                 "all": {
                     "class": Adam,
@@ -100,28 +127,13 @@ def training_config():
                 },
             },
             "lr_schedulers": None,  # Learning rate schedulers
-            "eval_on_valid": True,  # Whether to eval and logs metrics on validation set during training or not
-            "eval_on_valid_interval": 2,  # Interval (in epochs) to evaluate during training
-            "set_name_focus_metric": "training-val",  # Which dataset to focus on to select best weights
-            "train_metrics": [
-                "loss_ce",
-                "cer",
-                "wer",
-                "wer_no_punct",
-            ],  # Metrics name for training
-            "eval_metrics": [
-                "cer",
-                "wer",
-                "wer_no_punct",
-            ],  # Metrics name for evaluation on validation set during training
-            "force_cpu": True,  # True for debug purposes
-            "max_char_prediction": 30,  # max number of token prediction
             # Keep teacher forcing rate to 20% during whole training
             "label_noise_scheduler": {
                 "min_error_rate": 0.2,
                 "max_error_rate": 0.2,
                 "total_num_steps": 5e4,
             },
+            "transfer_learning": None,
         },
     }
 
diff --git a/tests/test_training.py b/tests/test_training.py
index 9815ac08ca788e8bf6e140676b56583ff6da6abc..3c3b20acffd02326e310f70d0ca738646b1d836a 100644
--- a/tests/test_training.py
+++ b/tests/test_training.py
@@ -88,8 +88,8 @@ def test_train_and_test(
     tmp_path,
 ):
     # Use the tmp_path as base folder
-    training_config["training_params"]["output_folder"] = str(
-        tmp_path / training_config["training_params"]["output_folder"]
+    training_config["training"]["output_folder"] = str(
+        tmp_path / training_config["training"]["output_folder"]
     )
 
     train_and_test(0, training_config)
@@ -99,7 +99,7 @@ def test_train_and_test(
         expected_model = torch.load(FIXTURES / "training" / "models" / model_name)
         trained_model = torch.load(
             tmp_path
-            / training_config["training_params"]["output_folder"]
+            / training_config["training"]["output_folder"]
             / "checkpoints"
             / model_name,
         )
@@ -114,7 +114,9 @@ def test_train_and_test(
                 expected_tensor,
             ) in zip(trained.items(), expected.items()):
                 assert trained_param == expected_param
-                assert torch.allclose(trained_tensor, expected_tensor, atol=1e-03)
+                assert torch.allclose(
+                    trained_tensor, expected_tensor, rtol=1e-05, atol=1e-03
+                )
 
         # Check the optimizer encoder and decoder state dicts
         for optimizer_part in [
@@ -169,7 +171,7 @@ def test_train_and_test(
     ):
         with (
             tmp_path
-            / training_config["training_params"]["output_folder"]
+            / training_config["training"]["output_folder"]
             / "results"
             / f"predict_training-{split_name}_0.yaml"
         ).open() as f:
@@ -184,7 +186,7 @@ def test_train_and_test(
     # Check that the inference parameters file is correct
     with (
         tmp_path
-        / training_config["training_params"]["output_folder"]
+        / training_config["training"]["output_folder"]
         / "results"
         / "inference_parameters.yml"
     ).open() as f: