Skip to content
Snippets Groups Projects
Commit 46b25cf7 authored by Yoann Schneider's avatar Yoann Schneider :tennis: Committed by Bastien Abadie
Browse files

Robust mlflow requests

parent 18f354a2
No related branches found
No related tags found
1 merge request!59Robust mlflow requests
...@@ -3,11 +3,22 @@ import os ...@@ -3,11 +3,22 @@ import os
from contextlib import contextmanager from contextlib import contextmanager
import mlflow import mlflow
from mlflow.exceptions import MlflowException import requests
from mlflow.environment_variables import MLFLOW_HTTP_REQUEST_MAX_RETRIES
from dan import logger from dan import logger
def make_mlflow_request(mlflow_method, *args, **kwargs):
"""
Encapsulate MLflow HTTP requests to prevent them from crashing the whole training process.
"""
try:
mlflow_method(*args, **kwargs)
except requests.exceptions.ConnectionError as e:
logger.error(f"Call to `{str(mlflow_method)}` failed with error: {str(e)}")
def setup_environment(config: dict): def setup_environment(config: dict):
""" """
Get the necessary variables from the config file and put them in the environment variables Get the necessary variables from the config file and put them in the environment variables
...@@ -24,6 +35,13 @@ def setup_environment(config: dict): ...@@ -24,6 +35,13 @@ def setup_environment(config: dict):
if config_key in config: if config_key in config:
os.environ[variable_name] = config[config_key] os.environ[variable_name] = config[config_key]
# Check max retry setting
max_retries = MLFLOW_HTTP_REQUEST_MAX_RETRIES.get()
if max_retries and int(max_retries) <= 1:
logger.warning(
f"The maximum number of retries for MLflow HTTP requests is set to {max_retries}, which is low. Consider using a higher value."
)
def logging_metrics( def logging_metrics(
display_values: dict, display_values: dict,
...@@ -42,10 +60,11 @@ def logging_metrics( ...@@ -42,10 +60,11 @@ def logging_metrics(
:param is_master: bool, makes sure you're on the right thread, defaults to False :param is_master: bool, makes sure you're on the right thread, defaults to False
""" """
if mlflow_logging and is_master: if mlflow_logging and is_master:
mlflow_values = { make_mlflow_request(
f"{step}_{name}": value for name, value in display_values.items() mlflow_method=mlflow.log_metrics,
} metrics={f"{step}_{name}": value for name, value in display_values.items()},
mlflow.log_metrics(mlflow_values, epoch) step=epoch,
)
def logging_tags_metrics( def logging_tags_metrics(
...@@ -63,10 +82,10 @@ def logging_tags_metrics( ...@@ -63,10 +82,10 @@ def logging_tags_metrics(
:param is_master: bool, makes sure you're on the right thread, defaults to False :param is_master: bool, makes sure you're on the right thread, defaults to False
""" """
if mlflow_logging and is_master: if mlflow_logging and is_master:
mlflow_values = { make_mlflow_request(
f"{step}_{name}": value for name, value in display_values.items() mlflow_method=mlflow.set_tags,
} tags={f"{step}_{name}": value for name, value in display_values.items()},
mlflow.set_tags(mlflow_values) )
@contextmanager @contextmanager
...@@ -83,13 +102,7 @@ def start_mlflow_run(config: dict): ...@@ -83,13 +102,7 @@ def start_mlflow_run(config: dict):
# Set experiment from config # Set experiment from config
experiment_id = config.get("experiment_id") experiment_id = config.get("experiment_id")
assert experiment_id, "Missing MLflow experiment ID in the configuration" assert experiment_id, "Missing MLflow experiment ID in the configuration"
try:
mlflow.set_experiment(experiment_id=experiment_id)
logger.info(f"Run Experiment ID : {experiment_id} on MLFlow")
except MlflowException as e:
logger.error(f"Couldn't set Mlflow experiment with ID: {experiment_id}")
raise e
# Start run # Start run
yield mlflow.start_run(run_name=config.get("run_name")) yield mlflow.start_run(run_name=config.get("run_name"), experiment_id=experiment_id)
mlflow.end_run() mlflow.end_run()
...@@ -24,7 +24,7 @@ try: ...@@ -24,7 +24,7 @@ try:
MLFLOW = True MLFLOW = True
logger.info("MLflow Logging available.") logger.info("MLflow Logging available.")
from dan.mlflow import start_mlflow_run from dan.mlflow import make_mlflow_request, start_mlflow_run
except ImportError: except ImportError:
MLFLOW = False MLFLOW = False
...@@ -71,16 +71,16 @@ def get_config(): ...@@ -71,16 +71,16 @@ def get_config():
Retrieve model configuration Retrieve model configuration
""" """
dataset_name = "esposalles" dataset_name = "esposalles"
dataset_level = "page" dataset_level = "record"
dataset_variant = "" dataset_variant = "_debug"
dataset_path = "/home/training_data/ATR_paragraph/Esposalles" dataset_path = "."
params = { params = {
"mlflow": { "mlflow": {
"dataset_name": dataset_name, "dataset_name": dataset_name,
"run_name": "Test log DAN", "run_name": "Test log DAN",
"s3_endpoint_url": "", "s3_endpoint_url": "",
"tracking_uri": "", "tracking_uri": "",
"experiment_id": "9", "experiment_id": "0",
"aws_access_key_id": "", "aws_access_key_id": "",
"aws_secret_access_key": "", "aws_secret_access_key": "",
}, },
...@@ -103,6 +103,11 @@ def get_config(): ...@@ -103,6 +103,11 @@ def get_config():
(dataset_name, "val"), (dataset_name, "val"),
], ],
}, },
"test": {
"{}-test".format(dataset_name): [
(dataset_name, "test"),
],
},
"config": { "config": {
"load_in_memory": True, # Load all images in CPU memory "load_in_memory": True, # Load all images in CPU memory
"worker_per_gpu": 4, # Num of parallel processes per gpu for data loading "worker_per_gpu": 4, # Num of parallel processes per gpu for data loading
...@@ -287,18 +292,27 @@ def run(): ...@@ -287,18 +292,27 @@ def run():
/ "labels.json" / "labels.json"
) )
with start_mlflow_run(config["mlflow"]) as run: with start_mlflow_run(config["mlflow"]) as run:
logger.info(f"Set tags to MLflow on {config['mlflow']['run_name']}") logger.info(f"Started MLflow run with ID ({run.info.run_id})")
mlflow.set_tags({"Dataset": config["mlflow"]["dataset_name"]})
make_mlflow_request(
mlflow_method=mlflow.set_tags, tags={"Dataset": dataset_name}
)
# Get the labels json file # Get the labels json file
with open(labels_path) as json_file: with open(labels_path) as json_file:
labels_artifact = json.load(json_file) labels_artifact = json.load(json_file)
# Log MLflow artifacts # Log MLflow artifacts
mlflow.log_dict(config_artifact, "config.json") for artifact, filename in [
mlflow.log_dict(labels_artifact, "labels.json") (config_artifact, "config.json"),
(labels_artifact, "labels.json"),
]:
make_mlflow_request(
mlflow_method=mlflow.log_dict,
dictionary=artifact,
artifact_file=filename,
)
logger.info(f"Started MLflow run with ID ({run.info.run_id})")
if ( if (
config["training_params"]["use_ddp"] config["training_params"]["use_ddp"]
and not config["training_params"]["force_cpu"] and not config["training_params"]["force_cpu"]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment