From 5a7dd7ef0296ac09306a346b99029b20fc203969 Mon Sep 17 00:00:00 2001 From: Yoann Schneider <yschneider@teklia.com> Date: Tue, 4 Jul 2023 14:32:55 +0000 Subject: [PATCH] Correctly set mlflow logging --- dan/ocr/document/train.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/dan/ocr/document/train.py b/dan/ocr/document/train.py index 4312c158..189bc2b9 100644 --- a/dan/ocr/document/train.py +++ b/dan/ocr/document/train.py @@ -258,18 +258,18 @@ def serialize_config(config): return serialized_config -def start_training(config) -> None: +def start_training(config, mlflow_logging: bool) -> None: if ( config["training_params"]["use_ddp"] and not config["training_params"]["force_cpu"] ): mp.spawn( train_and_test, - args=(config, True), + args=(config, mlflow_logging), nprocs=config["training_params"]["nb_gpu"], ) else: - train_and_test(0, config, True) + train_and_test(0, config, mlflow_logging) def run(): @@ -286,7 +286,7 @@ def run(): raise MLflowNotInstalled() if "mlflow" not in config: - start_training(config) + start_training(config, mlflow_logging=False) else: labels_path = ( Path(config["dataset_params"]["datasets"][dataset_name]) / "labels.json" @@ -314,4 +314,4 @@ def run(): dictionary=artifact, artifact_file=filename, ) - start_training(config) + start_training(config, mlflow_logging=True) -- GitLab