diff --git a/MANIFEST.in b/MANIFEST.in index 889ea0afe24d0a5a0ba61700c5d3a88d8e1960ed..81bfdbd1a444551752862576a06dc160484ff009 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,3 +1,4 @@ include requirements.txt include doc-requirements.txt +include mlflow-requirements.txt include VERSION diff --git a/dan/manager/training.py b/dan/manager/training.py index 67b9f6bf1b38c8dce424c94275676e26403c5248..fa3fbe12a562c7445dc8a281fb80909f4993e3f7 100644 --- a/dan/manager/training.py +++ b/dan/manager/training.py @@ -19,16 +19,13 @@ from tqdm import tqdm from dan.manager.metrics import MetricManager from dan.manager.ocr import OCRDatasetManager +from dan.mlflow import MLFLOW_AVAILABLE, logging_metrics, logging_tags_metrics from dan.ocr.utils import LM_ind_to_str from dan.schedulers import DropoutScheduler -try: +if MLFLOW_AVAILABLE: import mlflow - from dan.mlflow import logging_metrics, logging_tags_metrics -except ImportError: - pass - class GenericTrainingManager: def __init__(self, params): diff --git a/dan/mlflow.py b/dan/mlflow.py index b89342ed122182518c1fa0068ae45da52983fb71..fe166e8364b3f04a3d5dfa657618c143b7db6d9f 100644 --- a/dan/mlflow.py +++ b/dan/mlflow.py @@ -1,14 +1,38 @@ # -*- coding: utf-8 -*- +import functools +import logging import os from contextlib import contextmanager -import mlflow import requests -from mlflow.environment_variables import MLFLOW_HTTP_REQUEST_MAX_RETRIES -from dan import logger +logger = logging.getLogger(__name__) +try: + import mlflow + from mlflow.environment_variables import MLFLOW_HTTP_REQUEST_MAX_RETRIES + MLFLOW_AVAILABLE = True + logger.info("MLflow logging is available.") +except ImportError: + MLFLOW_AVAILABLE = False + + +def mlflow_required(func): + """ + Always check that MLflow is available before executing the function. + """ + + @functools.wraps(func) + def wrapper(self, *args, **kwargs): + if not MLFLOW_AVAILABLE: + return + return func(self, *args, **kwargs) + + return wrapper + + +@mlflow_required def make_mlflow_request(mlflow_method, *args, **kwargs): """ Encapsulate MLflow HTTP requests to prevent them from crashing the whole training process. @@ -19,6 +43,7 @@ def make_mlflow_request(mlflow_method, *args, **kwargs): logger.error(f"Call to `{str(mlflow_method)}` failed with error: {str(e)}") +@mlflow_required def setup_environment(config: dict): """ Get the necessary variables from the config file and put them in the environment variables @@ -43,6 +68,7 @@ def setup_environment(config: dict): ) +@mlflow_required def logging_metrics( display_values: dict, step: str, @@ -67,6 +93,7 @@ def logging_metrics( ) +@mlflow_required def logging_tags_metrics( display_values: dict, step: str, @@ -88,6 +115,7 @@ def logging_tags_metrics( ) +@mlflow_required @contextmanager def start_mlflow_run(config: dict): """ diff --git a/dan/ocr/document/train.py b/dan/ocr/document/train.py index 282c8b2630e254ea4215de287861589f3eff9fb4..4745fb97709e9b83a35750ffc646c474d74742ea 100644 --- a/dan/ocr/document/train.py +++ b/dan/ocr/document/train.py @@ -10,23 +10,18 @@ import torch import torch.multiprocessing as mp from torch.optim import Adam -from dan import logger from dan.decoder import GlobalHTADecoder from dan.encoder import FCN_Encoder from dan.manager.training import Manager +from dan.mlflow import MLFLOW_AVAILABLE from dan.schedulers import exponential_dropout_scheduler from dan.transforms import aug_config from dan.utils import MLflowNotInstalled -try: +if MLFLOW_AVAILABLE: import mlflow - MLFLOW = True - logger.info("MLflow Logging available.") - from dan.mlflow import make_mlflow_request, start_mlflow_run -except ImportError: - MLFLOW = False logger = logging.getLogger(__name__) @@ -271,6 +266,20 @@ def serialize_config(config): return serialized_config +def start_training(config) -> None: + if ( + config["training_params"]["use_ddp"] + and not config["training_params"]["force_cpu"] + ): + mp.spawn( + train_and_test, + args=(config, True), + nprocs=config["training_params"]["nb_gpu"], + ) + else: + train_and_test(0, config, True) + + def run(): """ Main program, training a new model, using a valid configuration @@ -278,7 +287,15 @@ def run(): config, dataset_name = get_config() - if MLFLOW and "mlflow" in config: + if "mlflow" in config and not MLFLOW_AVAILABLE: + logger.error( + "Cannot log to MLflow. Please install the `mlflow` extra requirements." + ) + raise MLflowNotInstalled() + + if "mlflow" not in config: + start_training(config) + else: labels_path = ( Path(config["dataset_params"]["datasets"][dataset_name]) / "labels.json" ) @@ -305,31 +322,4 @@ def run(): dictionary=artifact, artifact_file=filename, ) - if ( - config["training_params"]["use_ddp"] - and not config["training_params"]["force_cpu"] - ): - mp.spawn( - train_and_test, - args=(config, True), - nprocs=config["training_params"]["nb_gpu"], - ) - else: - train_and_test(0, config, True) - elif "mlflow" in config: - logger.error( - "Cannot log to MLflow as the `mlflow` module was not found in your environment." - ) - raise MLflowNotInstalled() - else: - if ( - config["training_params"]["use_ddp"] - and not config["training_params"]["force_cpu"] - ): - mp.spawn( - train_and_test, - args=(config, True), - nprocs=config["training_params"]["nb_gpu"], - ) - else: - train_and_test(0, config, True) + start_training(config) diff --git a/docs/usage/train/parameters.md b/docs/usage/train/parameters.md index ba448484c39202e4e2dfbab76f0c7d9f3567f22f..7447e082b23ce35af67ace5203f91079fea4e788 100644 --- a/docs/usage/train/parameters.md +++ b/docs/usage/train/parameters.md @@ -235,7 +235,14 @@ The following configuration is used by default when using the `teklia-dan train ## MLFlow logging -To log your experiment on MLFlow, update the following arguments. +To log your experiment on MLFlow, you need to: +- install the extra requirements via + + ```shell + $ pip install .[mlflow] + ``` + +- update the following arguments: | Name | Description | Type | Default | | ------------------------------ | ------------------------------------ | ----- | ------- | diff --git a/mlflow-requirements.txt b/mlflow-requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..3dcc340d6ef2b4736208dff59727df00d9ec0b92 --- /dev/null +++ b/mlflow-requirements.txt @@ -0,0 +1,2 @@ +mlflow-skinny==2.2.2 +pandas==2.0.0 diff --git a/requirements.txt b/requirements.txt index 7942bd5d8c872c60acb538bf3c3714900e3a9785..39d00baf289bc73fa171cf313269e4a9822ed31a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,11 +2,8 @@ arkindex-export==0.1.3 boto3==1.26.124 editdistance==0.6.2 imageio==2.26.1 -mlflow-skinny==2.2.2 numpy==1.24.3 opencv-python==4.7.0.72 -# Needed for mlflow -pandas==2.0.0 PyYAML==6.0 scipy==1.10.1 tensorboard==2.12.2 diff --git a/setup.py b/setup.py index d9a7455948797678587384ee3a4c406ca3c4a4cc..da61be6b543684d995dd535f4fc1e4c97964cbfd 100755 --- a/setup.py +++ b/setup.py @@ -28,5 +28,8 @@ setup( "teklia-dan=dan.cli:main", ] }, - extras_require={"docs": parse_requirements("doc-requirements.txt")}, + extras_require={ + "docs": parse_requirements("doc-requirements.txt"), + "mlflow": parse_requirements("mlflow-requirements.txt"), + }, )