Commit 46b25cf7 authored by Yoann Schneider, committed by Bastien Abadie

Robust mlflow requests

parent 18f354a2
Merge request !59: Robust mlflow requests
@@ -3,11 +3,22 @@ import os
from contextlib import contextmanager
import mlflow
from mlflow.exceptions import MlflowException
import requests
from mlflow.environment_variables import MLFLOW_HTTP_REQUEST_MAX_RETRIES
from dan import logger
def make_mlflow_request(mlflow_method, *args, **kwargs):
"""
Encapsulate MLflow HTTP requests to prevent them from crashing the whole training process.
"""
try:
mlflow_method(*args, **kwargs)
except requests.exceptions.ConnectionError as e:
logger.error(f"Call to `{str(mlflow_method)}` failed with error: {str(e)}")
def setup_environment(config: dict):
"""
Get the necessary variables from the config file and put them in the environment variables
@@ -24,6 +35,13 @@ def setup_environment(config: dict):
if config_key in config:
os.environ[variable_name] = config[config_key]
# Check max retry setting
max_retries = MLFLOW_HTTP_REQUEST_MAX_RETRIES.get()
if max_retries and int(max_retries) <= 1:
logger.warning(
f"The maximum number of retries for MLflow HTTP requests is set to {max_retries}, which is low. Consider using a higher value."
)
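The retry budget being checked is MLflow's standard `MLFLOW_HTTP_REQUEST_MAX_RETRIES` environment variable, so it can be raised before training starts. A short sketch (the value 8 is arbitrary):

```python
import os

# Export before MLflow issues any HTTP request; equivalently,
# set MLFLOW_HTTP_REQUEST_MAX_RETRIES=8 in the shell.
os.environ["MLFLOW_HTTP_REQUEST_MAX_RETRIES"] = "8"

from mlflow.environment_variables import MLFLOW_HTTP_REQUEST_MAX_RETRIES

print(MLFLOW_HTTP_REQUEST_MAX_RETRIES.get())  # -> 8
```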
def logging_metrics(
display_values: dict,
@@ -42,10 +60,11 @@ def logging_metrics(
:param is_master: bool, makes sure you're on the right thread, defaults to False
"""
if mlflow_logging and is_master:
mlflow_values = {
f"{step}_{name}": value for name, value in display_values.items()
}
mlflow.log_metrics(mlflow_values, epoch)
make_mlflow_request(
mlflow_method=mlflow.log_metrics,
metrics={f"{step}_{name}": value for name, value in display_values.items()},
step=epoch,
)
def logging_tags_metrics(
@@ -63,10 +82,10 @@
:param is_master: bool, makes sure you're on the right thread, defaults to False
"""
if mlflow_logging and is_master:
mlflow_values = {
f"{step}_{name}": value for name, value in display_values.items()
}
mlflow.set_tags(mlflow_values)
make_mlflow_request(
mlflow_method=mlflow.set_tags,
tags={f"{step}_{name}": value for name, value in display_values.items()},
)
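Both logging helpers now route through the wrapper and forward MLflow's own keyword arguments untouched. A sketch of the resulting calls (the split name, metric names, and epoch are illustrative):

```python
import mlflow

from dan.mlflow import make_mlflow_request

display_values = {"cer": 0.12, "wer": 0.31}  # illustrative values
step, epoch = "val", 7

# Metrics are prefixed with the split name ("val_cer", "val_wer")
# and logged against the epoch as the MLflow step.
make_mlflow_request(
    mlflow_method=mlflow.log_metrics,
    metrics={f"{step}_{name}": value for name, value in display_values.items()},
    step=epoch,
)

# Tags use the same naming scheme but carry no step.
make_mlflow_request(
    mlflow_method=mlflow.set_tags,
    tags={f"{step}_{name}": value for name, value in display_values.items()},
)
```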
@contextmanager
@@ -83,13 +102,7 @@ def start_mlflow_run(config: dict):
# Set experiment from config
experiment_id = config.get("experiment_id")
assert experiment_id, "Missing MLflow experiment ID in the configuration"
try:
mlflow.set_experiment(experiment_id=experiment_id)
logger.info(f"Run Experiment ID : {experiment_id} on MLFlow")
except MlflowException as e:
logger.error(f"Couldn't set Mlflow experiment with ID: {experiment_id}")
raise e
# Start run
yield mlflow.start_run(run_name=config.get("run_name"))
yield mlflow.start_run(run_name=config.get("run_name"), experiment_id=experiment_id)
mlflow.end_run()
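With `set_experiment` dropped, the experiment is now bound directly in `mlflow.start_run`. A minimal usage sketch (only the two keys below are visibly read by this function in the diff; the real config dict also carries tracking and S3 credentials consumed elsewhere):

```python
from dan.mlflow import start_mlflow_run

mlflow_config = {
    "experiment_id": "0",        # required: asserted before the run starts
    "run_name": "Test log DAN",  # optional run label
}

# The context manager yields the active run and calls
# mlflow.end_run() when the block exits.
with start_mlflow_run(mlflow_config) as run:
    print(run.info.run_id)
```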
@@ -24,7 +24,7 @@ try:
MLFLOW = True
logger.info("MLflow Logging available.")
from dan.mlflow import start_mlflow_run
from dan.mlflow import make_mlflow_request, start_mlflow_run
except ImportError:
MLFLOW = False
@@ -71,16 +71,16 @@ def get_config():
Retrieve model configuration
"""
dataset_name = "esposalles"
dataset_level = "page"
dataset_variant = ""
dataset_path = "/home/training_data/ATR_paragraph/Esposalles"
dataset_level = "record"
dataset_variant = "_debug"
dataset_path = "."
params = {
"mlflow": {
"dataset_name": dataset_name,
"run_name": "Test log DAN",
"s3_endpoint_url": "",
"tracking_uri": "",
"experiment_id": "9",
"experiment_id": "0",
"aws_access_key_id": "",
"aws_secret_access_key": "",
},
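For context, `setup_environment` copies these keys into environment variables via `os.environ[variable_name] = config[config_key]`. The exact key-to-variable mapping is outside this diff, so the variable names below are MLflow's standard ones and an assumption here:

```python
import os

# Assumed config-key -> environment-variable mapping (not shown in the diff).
ENV_MAPPING = {
    "MLFLOW_TRACKING_URI": "tracking_uri",
    "MLFLOW_S3_ENDPOINT_URL": "s3_endpoint_url",
    "AWS_ACCESS_KEY_ID": "aws_access_key_id",
    "AWS_SECRET_ACCESS_KEY": "aws_secret_access_key",
}

def setup_environment_sketch(config: dict):
    # Mirror of the loop shown in the diff, under the mapping above.
    for variable_name, config_key in ENV_MAPPING.items():
        if config_key in config:
            os.environ[variable_name] = config[config_key]
```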
@@ -103,6 +103,11 @@ def get_config():
(dataset_name, "val"),
],
},
"test": {
"{}-test".format(dataset_name): [
(dataset_name, "test"),
],
},
"config": {
"load_in_memory": True, # Load all images in CPU memory
"worker_per_gpu": 4, # Num of parallel processes per gpu for data loading
@@ -287,18 +292,27 @@ def run():
/ "labels.json"
)
with start_mlflow_run(config["mlflow"]) as run:
logger.info(f"Set tags to MLflow on {config['mlflow']['run_name']}")
mlflow.set_tags({"Dataset": config["mlflow"]["dataset_name"]})
logger.info(f"Started MLflow run with ID ({run.info.run_id})")
make_mlflow_request(
mlflow_method=mlflow.set_tags, tags={"Dataset": dataset_name}
)
# Get the labels json file
with open(labels_path) as json_file:
labels_artifact = json.load(json_file)
# Log MLflow artifacts
mlflow.log_dict(config_artifact, "config.json")
mlflow.log_dict(labels_artifact, "labels.json")
for artifact, filename in [
(config_artifact, "config.json"),
(labels_artifact, "labels.json"),
]:
make_mlflow_request(
mlflow_method=mlflow.log_dict,
dictionary=artifact,
artifact_file=filename,
)
logger.info(f"Started MLflow run with ID ({run.info.run_id})")
if (
config["training_params"]["use_ddp"]
and not config["training_params"]["force_cpu"]
...
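Taken together, the commit makes MLflow availability a soft dependency of training. A self-contained sketch of the resulting behavior (`flaky_method` is a stand-in, not part of the codebase):

```python
import requests

from dan.mlflow import make_mlflow_request

def flaky_method(**kwargs):
    # Stand-in for any MLflow call against an unreachable tracking server.
    raise requests.exceptions.ConnectionError("tracking server unreachable")

# Previously this exception would propagate and abort the run;
# now each failure is logged and the loop completes.
for epoch in range(3):
    make_mlflow_request(flaky_method, metrics={"loss": 1.0}, step=epoch)
print("training survived the MLflow outage")
```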