From ab27436624b254d5fa20b59b523262a00574c702 Mon Sep 17 00:00:00 2001
From: Yoann Schneider <yschneider@teklia.com>
Date: Fri, 10 Feb 2023 12:50:21 +0100
Subject: [PATCH] Allow resuming from an existing MLflow run

---
 dan/mlflow.py             | 14 +++++++++++++-
 dan/ocr/document/train.py |  2 +-
 2 files changed, 14 insertions(+), 2 deletions(-)

diff --git a/dan/mlflow.py b/dan/mlflow.py
index 9fdd4a65..55efde33 100644
--- a/dan/mlflow.py
+++ b/dan/mlflow.py
@@ -91,10 +91,22 @@ def start_mlflow_run(config: dict):
     # Set needed variables in environment
     setup_environment(config)
 
+    run_name, run_id = config.get("run_name"), config.get("run_id")
+
+    if run_id:
+        logger.info(f"Will resume run ({run_id}).")
+
+        if run_name:
+            logger.warning(
+                "Run_name will be ignored since you specified a run_id to resume from."
+            )
+
     # Set experiment from config
     experiment_id = config.get("experiment_id")
     assert experiment_id, "Missing MLflow experiment ID in the configuration"
 
     # Start run
-    yield mlflow.start_run(run_name=config.get("run_name"), experiment_id=experiment_id)
+    yield mlflow.start_run(
+        run_id=run_id, run_name=run_name, experiment_id=experiment_id
+    )
     mlflow.end_run()
diff --git a/dan/ocr/document/train.py b/dan/ocr/document/train.py
index 72abb8b8..04eb24ad 100644
--- a/dan/ocr/document/train.py
+++ b/dan/ocr/document/train.py
@@ -76,8 +76,8 @@ def get_config():
     dataset_path = "."
     params = {
         "mlflow": {
-            "dataset_name": dataset_name,
             "run_name": "Test log DAN",
+            "run_id": None,
             "s3_endpoint_url": "",
             "tracking_uri": "",
             "experiment_id": "0",
-- 
GitLab