From 82d306552f76dd0ca947652bc458ee9bd17cca41 Mon Sep 17 00:00:00 2001
From: Yoann Schneider <yschneider@teklia.com>
Date: Fri, 10 Feb 2023 11:05:10 +0100
Subject: [PATCH] do not force mlflow

---
 dan/mlflow.py             |  6 ++++++
 dan/ocr/document/train.py | 22 ++++++++++++++++++----
 2 files changed, 24 insertions(+), 4 deletions(-)

diff --git a/dan/mlflow.py b/dan/mlflow.py
index 063c2b91..a0fba3cc 100644
--- a/dan/mlflow.py
+++ b/dan/mlflow.py
@@ -8,6 +8,12 @@ from mlflow.exceptions import MlflowException
 from dan import logger
 
 
+class MLflowNotInstalled(Exception):
+    """
+    Raised when MLflow logging was requested but the module was not installed
+    """
+
+
 def setup_environment(config: dict):
     """
     Get the necessary variables from the config file and put them in the environment variables
diff --git a/dan/ocr/document/train.py b/dan/ocr/document/train.py
index a0366aac..eff73958 100644
--- a/dan/ocr/document/train.py
+++ b/dan/ocr/document/train.py
@@ -5,20 +5,28 @@ import random
 from copy import deepcopy
 from pathlib import Path
 
-import mlflow
 import numpy as np
 import torch
 import torch.multiprocessing as mp
 from torch.optim import Adam
 
+from dan import logger
 from dan.decoder import GlobalHTADecoder
 from dan.manager.ocr import OCRDataset, OCRDatasetManager
 from dan.manager.training import Manager
-from dan.mlflow import start_mlflow_run
+from dan.mlflow import MLflowNotInstalled, start_mlflow_run
 from dan.models import FCN_Encoder
 from dan.schedulers import exponential_dropout_scheduler
 from dan.transforms import aug_config
 
+try:
+    import mlflow
+
+    MLFLOW = True
+    logger.info("MLflow Logging available.")
+except ImportError:
+    MLFLOW = False
+
 logger = logging.getLogger(__name__)
 
 
@@ -272,13 +280,14 @@ def run():
     config = get_config()
     config_artifact = serialize_config(config)
     labels_artifact = ""
-    dataset_name = config["mlflow"]["dataset_name"]
+    # The only key of this dict is the name of the dataset
+    dataset_name = config_artifact["dataset_params"]["datasets"].keys()[0]
     labels_path = (
         Path(config_artifact["dataset_params"]["datasets"][dataset_name])
         / "labels.json"
     )
 
-    if config["mlflow"]:
+    if MLFLOW and config["mlflow"]:
         with start_mlflow_run(config["mlflow"]) as run:
             logger.info(f"Set tags to MLflow on {config['mlflow']['run_name']}")
             mlflow.set_tags({"Dataset": config["mlflow"]["dataset_name"]})
@@ -303,6 +312,11 @@ def run():
                 )
             else:
                 train_and_test(0, config, True)
+    elif config["mlflow"]:
+        logger.error(
+            "Cannot log to MLflow as the `mlflow` module was not found in your environment."
+        )
+        raise MLflowNotInstalled()
     else:
         if (
             config["training_params"]["use_ddp"]
-- 
GitLab