Skip to content
Snippets Groups Projects
Verified Commit 3277e063 authored by Yoann Schneider's avatar Yoann Schneider :tennis:
Browse files

last requests

parent 3dca228f
No related branches found
No related tags found
1 merge request!59Robust mlflow requests
This commit is part of merge request !59. Comments created here will be created in the context of that merge request.
...@@ -10,6 +10,9 @@ from dan import logger ...@@ -10,6 +10,9 @@ from dan import logger
def make_mlflow_request(mlflow_method, *args, **kwargs): def make_mlflow_request(mlflow_method, *args, **kwargs):
"""
Encapsulate MLflow HTTP requests to prevent them from crashing the whole training process.
"""
try: try:
mlflow_method(*args, **kwargs) mlflow_method(*args, **kwargs)
except requests.exceptions.ConnectionError as e: except requests.exceptions.ConnectionError as e:
...@@ -50,11 +53,10 @@ def logging_metrics( ...@@ -50,11 +53,10 @@ def logging_metrics(
:param is_master: bool, makes sure you're on the right thread, defaults to False :param is_master: bool, makes sure you're on the right thread, defaults to False
""" """
if mlflow_logging and is_master: if mlflow_logging and is_master:
mlflow_values = {
f"{step}_{name}": value for name, value in display_values.items()
}
make_mlflow_request( make_mlflow_request(
mlflow_method=mlflow.log_metrics, metrics=mlflow_values, step=epoch mlflow_method=mlflow.log_metrics, metrics={
f"{step}_{name}": value for name, value in display_values.items()
}, step=epoch
) )
...@@ -73,10 +75,10 @@ def logging_tags_metrics( ...@@ -73,10 +75,10 @@ def logging_tags_metrics(
:param is_master: bool, makes sure you're on the right thread, defaults to False :param is_master: bool, makes sure you're on the right thread, defaults to False
""" """
if mlflow_logging and is_master: if mlflow_logging and is_master:
mlflow_values = { make_mlflow_request(mlflow_method=mlflow.set_tags, tags=
{
f"{step}_{name}": value for name, value in display_values.items() f"{step}_{name}": value for name, value in display_values.items()
} })
make_mlflow_request(mlflow_method=mlflow.set_tags, tags=mlflow_values)
@contextmanager @contextmanager
...@@ -93,15 +95,7 @@ def start_mlflow_run(config: dict): ...@@ -93,15 +95,7 @@ def start_mlflow_run(config: dict):
# Set experiment from config # Set experiment from config
experiment_id = config.get("experiment_id") experiment_id = config.get("experiment_id")
assert experiment_id, "Missing MLflow experiment ID in the configuration" assert experiment_id, "Missing MLflow experiment ID in the configuration"
try:
make_mlflow_request(
mlflow_method=mlflow.set_experiment, experiment_id=experiment_id
)
logger.info(f"Run Experiment ID : {experiment_id} on MLFlow")
except MlflowException as e:
logger.error(f"Couldn't set Mlflow experiment with ID: {experiment_id}")
raise e
# Start run # Start run
yield mlflow.start_run(run_name=config.get("run_name")) yield mlflow.start_run(run_name=config.get("run_name"), experiment_id=experiment_id)
mlflow.end_run() mlflow.end_run()
...@@ -24,7 +24,7 @@ try: ...@@ -24,7 +24,7 @@ try:
MLFLOW = True MLFLOW = True
logger.info("MLflow Logging available.") logger.info("MLflow Logging available.")
from dan.mlflow import start_mlflow_run from dan.mlflow import start_mlflow_run, make_mlflow_request
except ImportError: except ImportError:
MLFLOW = False MLFLOW = False
...@@ -70,17 +70,17 @@ def get_config(): ...@@ -70,17 +70,17 @@ def get_config():
""" """
Retrieve model configuration Retrieve model configuration
""" """
dataset_name = "esposalles" dataset_name = "synist"
dataset_level = "page" dataset_level = "manual_text_lines"
dataset_variant = "" dataset_variant = ""
dataset_path = "/home/training_data/ATR_paragraph/Esposalles" dataset_path = "."
params = { params = {
"mlflow": { "mlflow": {
"dataset_name": dataset_name, "dataset_name": dataset_name,
"run_name": "Test log DAN", "run_name": "Test log DAN",
"s3_endpoint_url": "", "s3_endpoint_url": "",
"tracking_uri": "", "tracking_uri": "",
"experiment_id": "9", "experiment_id": "0",
"aws_access_key_id": "", "aws_access_key_id": "",
"aws_secret_access_key": "", "aws_secret_access_key": "",
}, },
...@@ -287,18 +287,25 @@ def run(): ...@@ -287,18 +287,25 @@ def run():
/ "labels.json" / "labels.json"
) )
with start_mlflow_run(config["mlflow"]) as run: with start_mlflow_run(config["mlflow"]) as run:
logger.info(f"Set tags to MLflow on {config['mlflow']['run_name']}") logger.info(f"Started MLflow run with ID ({run.info.run_id})")
mlflow.set_tags({"Dataset": config["mlflow"]["dataset_name"]})
make_mlflow_request(
mlflow_method=mlflow.set_tags,
tags={"Dataset": dataset_name}
)
# Get the labels json file # Get the labels json file
with open(labels_path) as json_file: with open(labels_path) as json_file:
labels_artifact = json.load(json_file) labels_artifact = json.load(json_file)
# Log MLflow artifacts # Log MLflow artifacts
mlflow.log_dict(config_artifact, "config.json") for artifact, filename in [(config_artifact, "config.json"), (labels_artifact, "labels.json")]:
mlflow.log_dict(labels_artifact, "labels.json") make_mlflow_request(
mlflow_method=mlflow.log_dict,
logger.info(f"Started MLflow run with ID ({run.info.run_id})") dictionary=artifact,
artifact_file=filename,
)
if ( if (
config["training_params"]["use_ddp"] config["training_params"]["use_ddp"]
and not config["training_params"]["force_cpu"] and not config["training_params"]["force_cpu"]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment