From c879ecfb0ff2b615182b4a24d536d8f03fa4968e Mon Sep 17 00:00:00 2001
From: Yoann Schneider <yschneider@teklia.com>
Date: Mon, 12 Jun 2023 11:00:48 +0000
Subject: [PATCH] Separate mlflow deps

---
 MANIFEST.in                    |  1 +
 dan/manager/training.py        |  7 ++--
 dan/mlflow.py                  | 34 +++++++++++++++++--
 dan/ocr/document/train.py      | 62 ++++++++++++++--------------------
 docs/usage/train/parameters.md |  9 ++++-
 mlflow-requirements.txt        |  2 ++
 requirements.txt               |  3 --
 setup.py                       |  5 ++-
 8 files changed, 74 insertions(+), 49 deletions(-)
 create mode 100644 mlflow-requirements.txt

diff --git a/MANIFEST.in b/MANIFEST.in
index 889ea0af..81bfdbd1 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -1,3 +1,4 @@
 include requirements.txt
 include doc-requirements.txt
+include mlflow-requirements.txt
 include VERSION
diff --git a/dan/manager/training.py b/dan/manager/training.py
index 67b9f6bf..fa3fbe12 100644
--- a/dan/manager/training.py
+++ b/dan/manager/training.py
@@ -19,16 +19,13 @@ from tqdm import tqdm
 
 from dan.manager.metrics import MetricManager
 from dan.manager.ocr import OCRDatasetManager
+from dan.mlflow import MLFLOW_AVAILABLE, logging_metrics, logging_tags_metrics
 from dan.ocr.utils import LM_ind_to_str
 from dan.schedulers import DropoutScheduler
 
-try:
+if MLFLOW_AVAILABLE:
     import mlflow
 
-    from dan.mlflow import logging_metrics, logging_tags_metrics
-except ImportError:
-    pass
-
 
 class GenericTrainingManager:
     def __init__(self, params):
diff --git a/dan/mlflow.py b/dan/mlflow.py
index b89342ed..fe166e83 100644
--- a/dan/mlflow.py
+++ b/dan/mlflow.py
@@ -1,14 +1,38 @@
 # -*- coding: utf-8 -*-
+import functools
+import logging
 import os
 from contextlib import contextmanager
 
-import mlflow
 import requests
-from mlflow.environment_variables import MLFLOW_HTTP_REQUEST_MAX_RETRIES
 
-from dan import logger
+logger = logging.getLogger(__name__)
 
+try:
+    import mlflow
+    from mlflow.environment_variables import MLFLOW_HTTP_REQUEST_MAX_RETRIES
 
+    MLFLOW_AVAILABLE = True
+    logger.info("MLflow logging is available.")
+except ImportError:
+    MLFLOW_AVAILABLE = False
+
+
+def mlflow_required(func):
+    """
+    Decorator: run `func` only if MLflow is available, otherwise do nothing and return None.
+    """
+    # No fixed positional parameter here, so keyword-only calls reach `func` too.
+    @functools.wraps(func)
+    def wrapper(*args, **kwargs):
+        if not MLFLOW_AVAILABLE:
+            return
+        return func(*args, **kwargs)
+
+    return wrapper
+
+
+@mlflow_required
 def make_mlflow_request(mlflow_method, *args, **kwargs):
     """
     Encapsulate MLflow HTTP requests to prevent them from crashing the whole training process.
@@ -19,6 +43,7 @@ def make_mlflow_request(mlflow_method, *args, **kwargs):
         logger.error(f"Call to `{str(mlflow_method)}` failed with error: {str(e)}")
 
 
+@mlflow_required
 def setup_environment(config: dict):
     """
     Get the necessary variables from the config file and put them in the environment variables
@@ -43,6 +68,7 @@ def setup_environment(config: dict):
         )
 
 
+@mlflow_required
 def logging_metrics(
     display_values: dict,
     step: str,
@@ -67,6 +93,7 @@ def logging_metrics(
         )
 
 
+@mlflow_required
 def logging_tags_metrics(
     display_values: dict,
     step: str,
@@ -88,6 +115,7 @@ def logging_tags_metrics(
         )
 
 
+@mlflow_required
 @contextmanager
 def start_mlflow_run(config: dict):
     """
diff --git a/dan/ocr/document/train.py b/dan/ocr/document/train.py
index 282c8b26..4745fb97 100644
--- a/dan/ocr/document/train.py
+++ b/dan/ocr/document/train.py
@@ -10,23 +10,18 @@ import torch
 import torch.multiprocessing as mp
 from torch.optim import Adam
 
-from dan import logger
 from dan.decoder import GlobalHTADecoder
 from dan.encoder import FCN_Encoder
 from dan.manager.training import Manager
+from dan.mlflow import MLFLOW_AVAILABLE
 from dan.schedulers import exponential_dropout_scheduler
 from dan.transforms import aug_config
 from dan.utils import MLflowNotInstalled
 
-try:
+if MLFLOW_AVAILABLE:
     import mlflow
 
-    MLFLOW = True
-    logger.info("MLflow Logging available.")
-
     from dan.mlflow import make_mlflow_request, start_mlflow_run
-except ImportError:
-    MLFLOW = False
 
 
 logger = logging.getLogger(__name__)
@@ -271,6 +266,20 @@ def serialize_config(config):
     return serialized_config
 
 
+def start_training(config) -> None:
+    """Call `train_and_test` directly, or via `mp.spawn` when `use_ddp` is set without `force_cpu`."""
+    training_params = config["training_params"]
+    distributed = training_params["use_ddp"] and not training_params["force_cpu"]
+    if not distributed:
+        train_and_test(0, config, True)
+        return
+    mp.spawn(
+        train_and_test,
+        args=(config, True),
+        nprocs=training_params["nb_gpu"],
+    )
+
+
 def run():
     """
     Main program, training a new model, using a valid configuration
@@ -278,7 +287,15 @@ def run():
 
     config, dataset_name = get_config()
 
-    if MLFLOW and "mlflow" in config:
+    if "mlflow" in config and not MLFLOW_AVAILABLE:
+        logger.error(
+            "Cannot log to MLflow. Please install the `mlflow` extra requirements."
+        )
+        raise MLflowNotInstalled()
+
+    if "mlflow" not in config:
+        start_training(config)
+    else:
         labels_path = (
             Path(config["dataset_params"]["datasets"][dataset_name]) / "labels.json"
         )
@@ -305,31 +322,4 @@ def run():
                     dictionary=artifact,
                     artifact_file=filename,
                 )
-            if (
-                config["training_params"]["use_ddp"]
-                and not config["training_params"]["force_cpu"]
-            ):
-                mp.spawn(
-                    train_and_test,
-                    args=(config, True),
-                    nprocs=config["training_params"]["nb_gpu"],
-                )
-            else:
-                train_and_test(0, config, True)
-    elif "mlflow" in config:
-        logger.error(
-            "Cannot log to MLflow as the `mlflow` module was not found in your environment."
-        )
-        raise MLflowNotInstalled()
-    else:
-        if (
-            config["training_params"]["use_ddp"]
-            and not config["training_params"]["force_cpu"]
-        ):
-            mp.spawn(
-                train_and_test,
-                args=(config, True),
-                nprocs=config["training_params"]["nb_gpu"],
-            )
-        else:
-            train_and_test(0, config, True)
+            start_training(config)
diff --git a/docs/usage/train/parameters.md b/docs/usage/train/parameters.md
index ba448484..7447e082 100644
--- a/docs/usage/train/parameters.md
+++ b/docs/usage/train/parameters.md
@@ -235,7 +235,14 @@ The following configuration is used by default when using the `teklia-dan train
 
 ## MLFlow logging
 
-To log your experiment on MLFlow, update the following arguments.
+To log your experiment on MLFlow, you need to:
+- install the extra requirements via
+
+    ```shell
+    $ pip install ".[mlflow]"
+    ```
+
+- update the following arguments:
 
 | Name                           | Description                          | Type  | Default |
 | ------------------------------ | ------------------------------------ | ----- | ------- |
diff --git a/mlflow-requirements.txt b/mlflow-requirements.txt
new file mode 100644
index 00000000..3dcc340d
--- /dev/null
+++ b/mlflow-requirements.txt
@@ -0,0 +1,2 @@
+mlflow-skinny==2.2.2
+pandas==2.0.0
diff --git a/requirements.txt b/requirements.txt
index 7942bd5d..39d00baf 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,11 +2,8 @@ arkindex-export==0.1.3
 boto3==1.26.124
 editdistance==0.6.2
 imageio==2.26.1
-mlflow-skinny==2.2.2
 numpy==1.24.3
 opencv-python==4.7.0.72
-# Needed for mlflow
-pandas==2.0.0
 PyYAML==6.0
 scipy==1.10.1
 tensorboard==2.12.2
diff --git a/setup.py b/setup.py
index d9a74559..da61be6b 100755
--- a/setup.py
+++ b/setup.py
@@ -28,5 +28,8 @@ setup(
             "teklia-dan=dan.cli:main",
         ]
     },
-    extras_require={"docs": parse_requirements("doc-requirements.txt")},
+    extras_require={
+        "docs": parse_requirements("doc-requirements.txt"),
+        "mlflow": parse_requirements("mlflow-requirements.txt"),
+    },
 )
-- 
GitLab