diff --git a/dan/ocr/document/train.py b/dan/ocr/document/train.py
index 4312c15836e5b9533ed5c08e15b7105b7614ebc5..189bc2b993c2cbcd37a2d21097ba40a72b32c773 100644
--- a/dan/ocr/document/train.py
+++ b/dan/ocr/document/train.py
@@ -258,18 +258,18 @@ def serialize_config(config):
     return serialized_config
 
 
-def start_training(config) -> None:
+def start_training(config, mlflow_logging: bool) -> None:
     if (
         config["training_params"]["use_ddp"]
         and not config["training_params"]["force_cpu"]
     ):
         mp.spawn(
             train_and_test,
-            args=(config, True),
+            args=(config, mlflow_logging),
             nprocs=config["training_params"]["nb_gpu"],
         )
     else:
-        train_and_test(0, config, True)
+        train_and_test(0, config, mlflow_logging)
 
 
 def run():
@@ -286,7 +286,7 @@ def run():
         raise MLflowNotInstalled()
 
     if "mlflow" not in config:
-        start_training(config)
+        start_training(config, mlflow_logging=False)
     else:
         labels_path = (
             Path(config["dataset_params"]["datasets"][dataset_name]) / "labels.json"
@@ -314,4 +314,4 @@ def run():
                     dictionary=artifact,
                     artifact_file=filename,
                 )
-            start_training(config)
+            start_training(config, mlflow_logging=True)