Skip to content
Snippets Groups Projects
Commit c879ecfb authored by Yoann Schneider's avatar Yoann Schneider :tennis: Committed by Mélodie Boillet
Browse files

Separate mlflow deps

parent 3626aa70
No related branches found
No related tags found
1 merge request!157Separate mlflow deps
include requirements.txt include requirements.txt
include doc-requirements.txt include doc-requirements.txt
include mlflow-requirements.txt
include VERSION include VERSION
...@@ -19,16 +19,13 @@ from tqdm import tqdm ...@@ -19,16 +19,13 @@ from tqdm import tqdm
from dan.manager.metrics import MetricManager from dan.manager.metrics import MetricManager
from dan.manager.ocr import OCRDatasetManager from dan.manager.ocr import OCRDatasetManager
from dan.mlflow import MLFLOW_AVAILABLE, logging_metrics, logging_tags_metrics
from dan.ocr.utils import LM_ind_to_str from dan.ocr.utils import LM_ind_to_str
from dan.schedulers import DropoutScheduler from dan.schedulers import DropoutScheduler
try: if MLFLOW_AVAILABLE:
import mlflow import mlflow
from dan.mlflow import logging_metrics, logging_tags_metrics
except ImportError:
pass
class GenericTrainingManager: class GenericTrainingManager:
def __init__(self, params): def __init__(self, params):
......
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import functools
import logging
import os import os
from contextlib import contextmanager from contextlib import contextmanager
import mlflow
import requests import requests
from mlflow.environment_variables import MLFLOW_HTTP_REQUEST_MAX_RETRIES
from dan import logger logger = logging.getLogger(__name__)
try:
import mlflow
from mlflow.environment_variables import MLFLOW_HTTP_REQUEST_MAX_RETRIES
MLFLOW_AVAILABLE = True
logger.info("MLflow logging is available.")
except ImportError:
MLFLOW_AVAILABLE = False
def mlflow_required(func):
    """
    Always check that MLflow is available before executing the function.

    When `MLFLOW_AVAILABLE` is false, the decorated callable is not executed
    and `None` is returned instead of its result.

    :param func: Callable to guard.
    :return: Wrapper forwarding every positional and keyword argument to `func`.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Forward the arguments untouched: the decorated callables are plain
        # functions, so no leading positional argument may be required here
        # (keyword-only or zero-argument calls must keep working).
        if not MLFLOW_AVAILABLE:
            return None
        return func(*args, **kwargs)

    return wrapper
@mlflow_required
def make_mlflow_request(mlflow_method, *args, **kwargs): def make_mlflow_request(mlflow_method, *args, **kwargs):
""" """
Encapsulate MLflow HTTP requests to prevent them from crashing the whole training process. Encapsulate MLflow HTTP requests to prevent them from crashing the whole training process.
...@@ -19,6 +43,7 @@ def make_mlflow_request(mlflow_method, *args, **kwargs): ...@@ -19,6 +43,7 @@ def make_mlflow_request(mlflow_method, *args, **kwargs):
logger.error(f"Call to `{str(mlflow_method)}` failed with error: {str(e)}") logger.error(f"Call to `{str(mlflow_method)}` failed with error: {str(e)}")
@mlflow_required
def setup_environment(config: dict): def setup_environment(config: dict):
""" """
Get the necessary variables from the config file and put them in the environment variables Get the necessary variables from the config file and put them in the environment variables
...@@ -43,6 +68,7 @@ def setup_environment(config: dict): ...@@ -43,6 +68,7 @@ def setup_environment(config: dict):
) )
@mlflow_required
def logging_metrics( def logging_metrics(
display_values: dict, display_values: dict,
step: str, step: str,
...@@ -67,6 +93,7 @@ def logging_metrics( ...@@ -67,6 +93,7 @@ def logging_metrics(
) )
@mlflow_required
def logging_tags_metrics( def logging_tags_metrics(
display_values: dict, display_values: dict,
step: str, step: str,
...@@ -88,6 +115,7 @@ def logging_tags_metrics( ...@@ -88,6 +115,7 @@ def logging_tags_metrics(
) )
@mlflow_required
@contextmanager @contextmanager
def start_mlflow_run(config: dict): def start_mlflow_run(config: dict):
""" """
......
...@@ -10,23 +10,18 @@ import torch ...@@ -10,23 +10,18 @@ import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
from torch.optim import Adam from torch.optim import Adam
from dan import logger
from dan.decoder import GlobalHTADecoder from dan.decoder import GlobalHTADecoder
from dan.encoder import FCN_Encoder from dan.encoder import FCN_Encoder
from dan.manager.training import Manager from dan.manager.training import Manager
from dan.mlflow import MLFLOW_AVAILABLE
from dan.schedulers import exponential_dropout_scheduler from dan.schedulers import exponential_dropout_scheduler
from dan.transforms import aug_config from dan.transforms import aug_config
from dan.utils import MLflowNotInstalled from dan.utils import MLflowNotInstalled
try: if MLFLOW_AVAILABLE:
import mlflow import mlflow
MLFLOW = True
logger.info("MLflow Logging available.")
from dan.mlflow import make_mlflow_request, start_mlflow_run from dan.mlflow import make_mlflow_request, start_mlflow_run
except ImportError:
MLFLOW = False
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -271,6 +266,20 @@ def serialize_config(config): ...@@ -271,6 +266,20 @@ def serialize_config(config):
return serialized_config return serialized_config
def start_training(config) -> None:
    """
    Run `train_and_test` with the given configuration.

    When `use_ddp` is enabled and `force_cpu` is not set in
    `config["training_params"]`, `nb_gpu` processes are spawned through
    `mp.spawn`. Otherwise `train_and_test` is called directly with rank 0.

    :param config: Training configuration, holding a `training_params` mapping.
    """
    training_params = config["training_params"]
    distributed = training_params["use_ddp"] and not training_params["force_cpu"]

    if not distributed:
        # Single-process run: call the entry point directly with rank 0.
        train_and_test(0, config, True)
        return

    mp.spawn(
        train_and_test,
        args=(config, True),
        nprocs=training_params["nb_gpu"],
    )
def run(): def run():
""" """
Main program, training a new model, using a valid configuration Main program, training a new model, using a valid configuration
...@@ -278,7 +287,15 @@ def run(): ...@@ -278,7 +287,15 @@ def run():
config, dataset_name = get_config() config, dataset_name = get_config()
if MLFLOW and "mlflow" in config: if "mlflow" in config and not MLFLOW_AVAILABLE:
logger.error(
"Cannot log to MLflow. Please install the `mlflow` extra requirements."
)
raise MLflowNotInstalled()
if "mlflow" not in config:
start_training(config)
else:
labels_path = ( labels_path = (
Path(config["dataset_params"]["datasets"][dataset_name]) / "labels.json" Path(config["dataset_params"]["datasets"][dataset_name]) / "labels.json"
) )
...@@ -305,31 +322,4 @@ def run(): ...@@ -305,31 +322,4 @@ def run():
dictionary=artifact, dictionary=artifact,
artifact_file=filename, artifact_file=filename,
) )
if ( start_training(config)
config["training_params"]["use_ddp"]
and not config["training_params"]["force_cpu"]
):
mp.spawn(
train_and_test,
args=(config, True),
nprocs=config["training_params"]["nb_gpu"],
)
else:
train_and_test(0, config, True)
elif "mlflow" in config:
logger.error(
"Cannot log to MLflow as the `mlflow` module was not found in your environment."
)
raise MLflowNotInstalled()
else:
if (
config["training_params"]["use_ddp"]
and not config["training_params"]["force_cpu"]
):
mp.spawn(
train_and_test,
args=(config, True),
nprocs=config["training_params"]["nb_gpu"],
)
else:
train_and_test(0, config, True)
...@@ -235,7 +235,14 @@ The following configuration is used by default when using the `teklia-dan train ...@@ -235,7 +235,14 @@ The following configuration is used by default when using the `teklia-dan train
## MLFlow logging ## MLFlow logging
To log your experiment on MLFlow, update the following arguments. To log your experiment on MLFlow, you need to:
- install the extra requirements via
```shell
$ pip install ".[mlflow]"
```
- update the following arguments:
| Name | Description | Type | Default | | Name | Description | Type | Default |
| ------------------------------ | ------------------------------------ | ----- | ------- | | ------------------------------ | ------------------------------------ | ----- | ------- |
......
mlflow-skinny==2.2.2
pandas==2.0.0
...@@ -2,11 +2,8 @@ arkindex-export==0.1.3 ...@@ -2,11 +2,8 @@ arkindex-export==0.1.3
boto3==1.26.124 boto3==1.26.124
editdistance==0.6.2 editdistance==0.6.2
imageio==2.26.1 imageio==2.26.1
mlflow-skinny==2.2.2
numpy==1.24.3 numpy==1.24.3
opencv-python==4.7.0.72 opencv-python==4.7.0.72
# Needed for mlflow
pandas==2.0.0
PyYAML==6.0 PyYAML==6.0
scipy==1.10.1 scipy==1.10.1
tensorboard==2.12.2 tensorboard==2.12.2
......
...@@ -28,5 +28,8 @@ setup( ...@@ -28,5 +28,8 @@ setup(
"teklia-dan=dan.cli:main", "teklia-dan=dan.cli:main",
] ]
}, },
extras_require={"docs": parse_requirements("doc-requirements.txt")}, extras_require={
"docs": parse_requirements("doc-requirements.txt"),
"mlflow": parse_requirements("mlflow-requirements.txt"),
},
) )
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment