From 5a7dd7ef0296ac09306a346b99029b20fc203969 Mon Sep 17 00:00:00 2001
From: Yoann Schneider <yschneider@teklia.com>
Date: Tue, 4 Jul 2023 14:32:55 +0000
Subject: [PATCH] Correctly set mlflow logging

---
 dan/ocr/document/train.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/dan/ocr/document/train.py b/dan/ocr/document/train.py
index 4312c158..189bc2b9 100644
--- a/dan/ocr/document/train.py
+++ b/dan/ocr/document/train.py
@@ -258,18 +258,18 @@ def serialize_config(config):
     return serialized_config
 
 
-def start_training(config) -> None:
+def start_training(config, mlflow_logging: bool) -> None:
     if (
         config["training_params"]["use_ddp"]
         and not config["training_params"]["force_cpu"]
     ):
         mp.spawn(
             train_and_test,
-            args=(config, True),
+            args=(config, mlflow_logging),
             nprocs=config["training_params"]["nb_gpu"],
         )
     else:
-        train_and_test(0, config, True)
+        train_and_test(0, config, mlflow_logging)
 
 
 def run():
@@ -286,7 +286,7 @@ def run():
         raise MLflowNotInstalled()
 
     if "mlflow" not in config:
-        start_training(config)
+        start_training(config, mlflow_logging=False)
     else:
         labels_path = (
             Path(config["dataset_params"]["datasets"][dataset_name]) / "labels.json"
@@ -314,4 +314,4 @@ def run():
                     dictionary=artifact,
                     artifact_file=filename,
                 )
-            start_training(config)
+            start_training(config, mlflow_logging=True)
-- 
GitLab