diff --git a/dan/manager/training.py b/dan/manager/training.py
index 40327d4e4d833b91edd72142c5939333ccbe38db..f1e3cb5bd269eb956558ca3f8dbd3ea96eddadc7 100644
--- a/dan/manager/training.py
+++ b/dan/manager/training.py
@@ -1,8 +1,8 @@
 # -*- coding: utf-8 -*-
-import json
 import os
 import random
 from copy import deepcopy
+from enum import Enum
 from time import time
 
 import numpy as np
@@ -454,7 +454,8 @@ class GenericTrainingManager:
 
     def save_params(self):
         """
-        Output text file containing a summary of all hyperparameters chosen for the training
+        Output a YAML file containing a summary of all hyperparameters chosen for training,
+        and a second YAML file containing the parameters needed for inference
         """
 
         def compute_nb_params(module):
@@ -462,15 +463,30 @@
 
         def class_to_str_dict(my_dict):
             for key in my_dict.keys():
-                if callable(my_dict[key]):
+                if key == "preprocessings":
+                    my_dict[key] = [
+                        {
+                            key: value.value if isinstance(value, Enum) else value
+                            for key, value in preprocessing.items()
+                        }
+                        for preprocessing in my_dict[key]
+                    ]
+                elif callable(my_dict[key]):
                     my_dict[key] = my_dict[key].__name__
                 elif isinstance(my_dict[key], np.ndarray):
                     my_dict[key] = my_dict[key].tolist()
+                elif (
+                    isinstance(my_dict[key], list)
+                    and my_dict[key]
+                    and isinstance(my_dict[key][0], tuple)
+                ):
+                    my_dict[key] = [list(elt) for elt in my_dict[key]]
                 elif isinstance(my_dict[key], dict):
                     my_dict[key] = class_to_str_dict(my_dict[key])
             return my_dict
 
-        path = os.path.join(self.paths["results"], "params")
+        # Save training parameters
+        path = os.path.join(self.paths["results"], "training_parameters.yml")
         if os.path.isfile(path):
             return
         params = class_to_str_dict(my_dict=deepcopy(self.params))
@@ -483,8 +499,45 @@
             ]
             total_params += current_params
         params["model_params"]["total_params"] = "{:,}".format(total_params)
+        params["mean"] = self.dataset.mean.tolist()
+        params["std"] = self.dataset.std.tolist()
+        with open(path, "w") as f:
+            yaml.dump(params, f)
+
+        # Save inference parameters
+        path = os.path.join(self.paths["results"], "inference_parameters.yml")
+        if os.path.isfile(path):
+            return
+        inference_params = {
+            "parameters": {
+                "mean": params["mean"],
+                "std": params["std"],
+                "max_char_prediction": params["training_params"]["max_char_prediction"],
+                "encoder": {
+                    "dropout": params["model_params"]["dropout"],
+                },
+                "decoder": {
+                    key: params["model_params"][key]
+                    for key in [
+                        "enc_dim",
+                        "l_max",
+                        "h_max",
+                        "w_max",
+                        "dec_num_layers",
+                        "dec_num_heads",
+                        "dec_res_dropout",
+                        "dec_pred_dropout",
+                        "dec_att_dropout",
+                        "dec_dim_feedforward",
+                        "vocab_size",
+                        "attention_win",
+                    ]
+                },
+                "preprocessings": params["dataset_params"]["config"]["preprocessings"],
+            },
+        }
         with open(path, "w") as f:
-            json.dump(params, f, indent=4)
+            yaml.dump(inference_params, f)
 
     def backward_loss(self, loss, retain_graph=False):
         self.scaler.scale(loss).backward(retain_graph=retain_graph)
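With this change, downstream code can reload the generated file with `yaml.safe_load`. A minimal sketch of a consumer, assuming PyYAML is installed, a finished run, and a hypothetical `training_params.output_folder` of `outputs` (the key structure below follows the `inference_params` dict built above):

```python
import yaml

# Hypothetical output location: {training_params.output_folder}/results/
with open("outputs/results/inference_parameters.yml") as f:
    inference_params = yaml.safe_load(f)["parameters"]

# Normalization statistics computed on the training set
mean, std = inference_params["mean"], inference_params["std"]

# Decoder hyperparameters needed to rebuild the model for prediction
decoder = inference_params["decoder"]
print(decoder["vocab_size"], decoder["attention_win"])
```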
diff --git a/docs/get_started/training.md b/docs/get_started/training.md
index 71cdd974da602e45d95e16213bbc39d15eb801ac..300d478ca1d7ff85c56a365af0717fe986ebaa8e 100644
--- a/docs/get_started/training.md
+++ b/docs/get_started/training.md
@@ -45,33 +45,2 @@ Once the training is complete, you can apply a trained DAN model on an image.
 
-To do this, you will need to:
-
-1. Create a `parameters.yml` file using the parameters saved during training in the `params` file, located in `{training_params.output_folder}/results`. This file should have the following format:
-```yml
-version: 0.0.1
-parameters:
-  mean: [float, float, float]
-  std: [float, float, float]
-  max_char_prediction: int
-  encoder:
-    dropout: float
-  decoder:
-    enc_dim: int
-    l_max: int
-    dec_pred_dropout: float
-    attention_win: int
-    vocab_size: int
-    h_max: int
-    w_max: int
-    dec_num_layers: int
-    dec_dim_feedforward: int
-    dec_num_heads: int
-    dec_att_dropout: float
-    dec_res_dropout: float
-  preprocessings:
-    - type: str
-      max_height: int
-      max_width: int
-      fixed_height: int
-      fixed_width: int
-```
-2. Apply a trained DAN model on an image using the [predict command](../usage/predict.md).
+To do this, apply a trained DAN model on an image using the [predict command](../usage/predict.md) and the `inference_parameters.yml` file generated during training, located in `{training_params.output_folder}/results`.
diff --git a/tests/data/prediction/parameters.yml b/tests/data/prediction/parameters.yml
index f07baaecb8a44ed36eb56fced287d5a76ad138cf..f6014227d4f7e12114ac777a079dcb65891799f0 100644
--- a/tests/data/prediction/parameters.yml
+++ b/tests/data/prediction/parameters.yml
@@ -1,5 +1,4 @@
 ---
-version: 0.0.1
 parameters:
   mean: [166.8418783515498, 166.8418783515498, 166.8418783515498]
   std: [34.084189571536385, 34.084189571536385, 34.084189571536385]
diff --git a/tests/test_training.py b/tests/test_training.py
index 13c03c8c46e6ebde56a08c89b487ed9d901e02c5..89452ad5093e070185e1b3e694f70c0c330eabd8 100644
--- a/tests/test_training.py
+++ b/tests/test_training.py
@@ -9,7 +9,7 @@ from tests.conftest import FIXTURES
 
 
 @pytest.mark.parametrize(
-    "expected_best_model_name, expected_last_model_name, training_res, val_res, test_res",
+    "expected_best_model_name, expected_last_model_name, training_res, val_res, test_res, params_res",
     (
         (
             "best_0.pt",
@@ -41,6 +41,39 @@ from tests.conftest import FIXTURES
                 "wer_no_punct": 1.0,
                 "nb_samples": 2,
             },
+            {
+                "parameters": {
+                    "max_char_prediction": 30,
+                    "encoder": {"dropout": 0.5},
+                    "decoder": {
+                        "enc_dim": 256,
+                        "l_max": 15000,
+                        "h_max": 500,
+                        "w_max": 1000,
+                        "dec_num_layers": 8,
+                        "dec_num_heads": 4,
+                        "dec_res_dropout": 0.1,
+                        "dec_pred_dropout": 0.1,
+                        "dec_att_dropout": 0.1,
+                        "dec_dim_feedforward": 256,
+                        "vocab_size": 96,
+                        "attention_win": 100,
+                    },
+                    "preprocessings": [
+                        {
+                            "max_height": 2000,
+                            "max_width": 2000,
+                            "type": "max_resize",
+                        }
+                    ],
+                    "mean": [
+                        242.10595854671013,
+                        242.10595854671013,
+                        242.10595854671013,
+                    ],
+                    "std": [28.29919517652322, 28.29919517652322, 28.29919517652322],
+                },
+            },
         ),
     ),
 )
@@ -50,6 +83,7 @@ def test_train_and_test(
     training_res,
     val_res,
     test_res,
+    params_res,
     training_config,
     tmp_path,
 ):
@@ -146,3 +180,13 @@ def test_train_and_test(
                 if "time" not in metric
             }
             assert res == expected_res
+
+    # Check that the inference parameters file is correct
+    with (
+        tmp_path
+        / training_config["training_params"]["output_folder"]
+        / "results"
+        / "inference_parameters.yml"
+    ).open() as f:
+        res = yaml.safe_load(f)
+        assert res == params_res
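
The `params_res` fixture above also shows the effect of the `Enum` handling added to `class_to_str_dict`: a preprocessing `type` is stored as its plain string value. A minimal illustration of that conversion, with a hypothetical `Preprocessing` enum standing in for the project's own (the fixture values `max_resize`, `2000` are taken from the test above):

```python
from enum import Enum


class Preprocessing(str, Enum):
    # Hypothetical stand-in for the project's preprocessing enum
    MAX_RESIZE = "max_resize"


entry = {"type": Preprocessing.MAX_RESIZE, "max_height": 2000, "max_width": 2000}

# Same comprehension as in class_to_str_dict above
serialized = {k: v.value if isinstance(v, Enum) else v for k, v in entry.items()}

assert serialized == {"type": "max_resize", "max_height": 2000, "max_width": 2000}
```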