diff --git a/dan/manager/training.py b/dan/manager/training.py
index a1c3b1e43f57c7e22da6680cd162f2550ac7d79b..7b565c41e9465531821f8af5ed739d1794670559 100644
--- a/dan/manager/training.py
+++ b/dan/manager/training.py
@@ -31,13 +31,6 @@ try:
 except ImportError:
     pass
 
-try:
-    import mlflow
-
-    from dan.mlflow import logging_metrics, logging_tags_metrics
-except ImportError:
-    pass
-
 
 class GenericTrainingManager:
     def __init__(self, params):
diff --git a/dan/mlflow.py b/dan/mlflow.py
index 95b41bf8eacb2ae8efb4fd865e93f66abf1e39bf..b89342ed122182518c1fa0068ae45da52983fb71 100644
--- a/dan/mlflow.py
+++ b/dan/mlflow.py
@@ -91,7 +91,14 @@ def logging_tags_metrics(
 @contextmanager
 def start_mlflow_run(config: dict):
     """
-    Create an MLflow execution context with the parameters contained in the config file
+    Create an MLflow execution context with the parameters contained in the config file.
+
+    Yields the active MLflow run and a boolean indicating whether a new run was created.
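+
+    Example:
+
+        with start_mlflow_run(config) as (run, created):
+            ...  # `created` is True if a new run was started, False if an existing run was resumed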
 
     :param config: dict, the config of the model
     """
@@ -116,5 +118,5 @@ def start_mlflow_run(config: dict):
     # Start run
     yield mlflow.start_run(
         run_id=run_id, run_name=run_name, experiment_id=experiment_id
-    )
+    ), run_id is None
     mlflow.end_run()
diff --git a/dan/ocr/document/train.py b/dan/ocr/document/train.py
index eaea50168a53bc6a37a04c2f38ec29f2a24dcc57..3eb15552ba975bbb294a00fd625a21ad7f1e4592 100644
--- a/dan/ocr/document/train.py
+++ b/dan/ocr/document/train.py
@@ -304,11 +304,14 @@ def run():
         labels_path = (
             Path(config["dataset_params"]["datasets"][dataset_name]) / "labels.json"
         )
-        with start_mlflow_run(config["mlflow"]) as run:
-            logger.info(f"Started MLflow run with ID ({run.info.run_id})")            
+        with start_mlflow_run(config["mlflow"]) as (run, created):
+            if created:
+                logger.info(f"Started MLflow run with ID ({run.info.run_id})")
+            else:
+                logger.info(f"Resumed MLflow run with ID ({run.info.run_id})")
+
             make_mlflow_request(
-                mlflow_method=mlflow.set_tags,
-                tags={"Dataset": dataset_name}
+                mlflow_method=mlflow.set_tags, tags={"Dataset": dataset_name}
             )
             # Get the labels json file
             with open(labels_path) as json_file: