From 30a8efce84de4b14f20d92b7e45db33923f4ed5b Mon Sep 17 00:00:00 2001
From: Yoann Schneider <yschneider@teklia.com>
Date: Wed, 15 Feb 2023 17:31:20 +0100
Subject: [PATCH] more logs

---
 dan/manager/training.py   |  7 -------
 dan/mlflow.py             |  6 ++++--
 dan/ocr/document/train.py | 11 +++++++----
 3 files changed, 11 insertions(+), 13 deletions(-)

diff --git a/dan/manager/training.py b/dan/manager/training.py
index a1c3b1e4..7b565c41 100644
--- a/dan/manager/training.py
+++ b/dan/manager/training.py
@@ -31,13 +31,6 @@ try:
 except ImportError:
     pass
 
-try:
-    import mlflow
-
-    from dan.mlflow import logging_metrics, logging_tags_metrics
-except ImportError:
-    pass
-
 
 class GenericTrainingManager:
     def __init__(self, params):
diff --git a/dan/mlflow.py b/dan/mlflow.py
index 95b41bf8..b89342ed 100644
--- a/dan/mlflow.py
+++ b/dan/mlflow.py
@@ -91,7 +91,9 @@ def logging_tags_metrics(
 @contextmanager
 def start_mlflow_run(config: dict):
     """
-    Create an MLflow execution context with the parameters contained in the config file
+    Create an MLflow execution context with the parameters contained in the config file.
+
+    Yields the active MLflow run, as well as a boolean saying whether a new one was created.
 
     :param config: dict, the config of the model
     """
@@ -116,5 +118,5 @@ def start_mlflow_run(config: dict):
     # Start run
     yield mlflow.start_run(
         run_id=run_id, run_name=run_name, experiment_id=experiment_id
-    )
+    ), run_id is None
     mlflow.end_run()
diff --git a/dan/ocr/document/train.py b/dan/ocr/document/train.py
index eaea5016..3eb15552 100644
--- a/dan/ocr/document/train.py
+++ b/dan/ocr/document/train.py
@@ -304,11 +304,14 @@ def run():
         labels_path = (
             Path(config["dataset_params"]["datasets"][dataset_name]) / "labels.json"
         )
-        with start_mlflow_run(config["mlflow"]) as run:
-            logger.info(f"Started MLflow run with ID ({run.info.run_id})")            
+        with start_mlflow_run(config["mlflow"]) as (run, created):
+            if created:
+                logger.info(f"Started MLflow run with ID ({run.info.run_id})")
+            else:
+                logger.info(f"Resumed MLflow run with ID ({run.info.run_id})")
+
             make_mlflow_request(
-                mlflow_method=mlflow.set_tags,
-                tags={"Dataset": dataset_name}
+                mlflow_method=mlflow.set_tags, tags={"Dataset": dataset_name}
             )
             # Get the labels json file
             with open(labels_path) as json_file:
-- 
GitLab