
Compare revisions

Changes are shown as if the source revision was being merged into the target revision.

Target project: atr/dan
Commits on Source (2)
@@ -3,11 +3,22 @@ import os
 from contextlib import contextmanager

 import mlflow
 from mlflow.exceptions import MlflowException
+import requests
+from mlflow.environment_variables import MLFLOW_HTTP_REQUEST_MAX_RETRIES

 from dan import logger


+def make_mlflow_request(mlflow_method, *args, **kwargs):
+    """
+    Encapsulate MLflow HTTP requests to prevent them from crashing the whole training process.
+    """
+    try:
+        mlflow_method(*args, **kwargs)
+    except requests.exceptions.ConnectionError as e:
+        logger.error(f"Call to `{str(mlflow_method)}` failed with error: {str(e)}")
+
+
 def setup_environment(config: dict):
     """
     Get the necessary variables from the config file and put them in the environment variables
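
Note: the new make_mlflow_request helper forwards its arguments to the given MLflow method and logs connection failures instead of raising them. A minimal call-site sketch (the mlflow.log_params call and its values are illustrative, not part of this diff):

# Keyword arguments are passed through to the wrapped MLflow method; a
# requests ConnectionError against the tracking server is logged rather
# than propagated, so training continues when the backend is unreachable.
make_mlflow_request(
    mlflow_method=mlflow.log_params,
    params={"batch_size": 4},  # illustrative parameters only
)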
@@ -24,6 +35,13 @@ def setup_environment(config: dict):
         if config_key in config:
             os.environ[variable_name] = config[config_key]

+    # Check max retry setting
+    max_retries = MLFLOW_HTTP_REQUEST_MAX_RETRIES.get()
+    if max_retries and int(max_retries) <= 1:
+        logger.warning(
+            f"The maximum number of retries for MLflow HTTP requests is set to {max_retries}, which is low. Consider using a higher value."
+        )
+

 def logging_metrics(
     display_values: dict,
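
Note: MLFLOW_HTTP_REQUEST_MAX_RETRIES is an environment-backed MLflow setting, so the check above warns whenever the configured retry budget is 1 or less. A sketch of raising it before training starts (the value 10 is an arbitrary example):

import os

# Must be set before MLflow issues its first HTTP request to take effect.
os.environ["MLFLOW_HTTP_REQUEST_MAX_RETRIES"] = "10"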
@@ -42,10 +60,11 @@ def logging_metrics(
     :param is_master: bool, makes sure you're on the right thread, defaults to False
     """
     if mlflow_logging and is_master:
-        mlflow_values = {
-            f"{step}_{name}": value for name, value in display_values.items()
-        }
-        mlflow.log_metrics(mlflow_values, epoch)
+        make_mlflow_request(
+            mlflow_method=mlflow.log_metrics,
+            metrics={f"{step}_{name}": value for name, value in display_values.items()},
+            step=epoch,
+        )


 def logging_tags_metrics(
@@ -63,16 +82,18 @@ def logging_tags_metrics(
     :param is_master: bool, makes sure you're on the right thread, defaults to False
     """
     if mlflow_logging and is_master:
-        mlflow_values = {
-            f"{step}_{name}": value for name, value in display_values.items()
-        }
-        mlflow.set_tags(mlflow_values)
+        make_mlflow_request(
+            mlflow_method=mlflow.set_tags,
+            tags={f"{step}_{name}": value for name, value in display_values.items()},
+        )


 @contextmanager
 def start_mlflow_run(config: dict):
     """
-    Create an MLflow execution context with the parameters contained in the config file
+    Create an MLflow execution context with the parameters contained in the config file.
+
+    Yields the active MLflow run, as well as a boolean saying whether a new one was created.

     :param config: dict, the config of the model
     """
@@ -80,16 +101,22 @@ def start_mlflow_run(config: dict):
     # Set needed variables in environment
     setup_environment(config)

+    run_name, run_id = config.get("run_name"), config.get("run_id")
+    if run_id:
+        logger.info(f"Will resume run ({run_id}).")
+
+        if run_name:
+            logger.warning(
+                "Run_name will be ignored since you specified a run_id to resume from."
+            )
+
     # Set experiment from config
     experiment_id = config.get("experiment_id")
     assert experiment_id, "Missing MLflow experiment ID in the configuration"
     try:
         mlflow.set_experiment(experiment_id=experiment_id)
         logger.info(f"Run Experiment ID : {experiment_id} on MLFlow")
     except MlflowException as e:
         logger.error(f"Couldn't set Mlflow experiment with ID: {experiment_id}")
         raise e

     # Start run
-    yield mlflow.start_run(run_name=config.get("run_name"))
+    yield mlflow.start_run(
+        run_id=run_id, run_name=run_name, experiment_id=experiment_id
+    ), run_id is None
     mlflow.end_run()
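
Note: a minimal usage sketch of the updated context manager; the config values below are placeholders:

from dan.mlflow import start_mlflow_run

config = {"experiment_id": "0", "run_name": "demo", "run_id": None}
with start_mlflow_run(config) as (run, created):
    # created is True when a fresh run was started (no run_id supplied),
    # and False when an existing run was resumed.
    print(run.info.run_id, created)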
@@ -24,10 +24,12 @@ try:
     MLFLOW = True
     logger.info("MLflow Logging available.")

-    from dan.mlflow import start_mlflow_run
+    from dan.mlflow import make_mlflow_request, start_mlflow_run
 except ImportError:
     MLFLOW = False

 logger = logging.getLogger(__name__)
@@ -71,16 +73,16 @@ def get_config():
     Retrieve model configuration
     """
     dataset_name = "esposalles"
-    dataset_level = "page"
-    dataset_variant = ""
-    dataset_path = "/home/training_data/ATR_paragraph/Esposalles"
+    dataset_level = "record"
+    dataset_variant = "_debug"
+    dataset_path = "."
     params = {
         "mlflow": {
             "dataset_name": dataset_name,
             "run_name": "Test log DAN",
+            "run_id": None,
             "s3_endpoint_url": "",
             "tracking_uri": "",
-            "experiment_id": "9",
+            "experiment_id": "0",
             "aws_access_key_id": "",
             "aws_secret_access_key": "",
         },
@@ -103,6 +105,11 @@ def get_config():
                 (dataset_name, "val"),
             ],
         },
+        "test": {
+            "{}-test".format(dataset_name): [
+                (dataset_name, "test"),
+            ],
+        },
         "config": {
             "load_in_memory": True,  # Load all images in CPU memory
             "worker_per_gpu": 4,  # Num of parallel processes per gpu for data loading
@@ -232,7 +239,10 @@ def get_config():
 def serialize_config(config):
     """
-    Serialize a dictionary to transform it into json and remove the credentials
+    Make every field of the configuration JSON-serializable and remove sensitive information.
+
+    - Classes are transformed using their name attribute
+    - Functions are cast to strings
     """
     # Create a copy of the original config without erasing it
     serialized_config = deepcopy(config)
@@ -270,6 +280,20 @@ def serialize_config(config):
     serialized_config["training_params"]["nb_gpu"] = str(
         serialized_config["training_params"]["nb_gpu"]
     )
+
+    if (
+        "synthetic_data" in config["dataset_params"]["config"]
+        and config["dataset_params"]["config"]["synthetic_data"]
+    ):
+        # The probability scheduler is a function and needs to be cast to a string
+        serialized_config["dataset_params"]["config"]["synthetic_data"][
+            "proba_scheduler_function"
+        ] = str(
+            serialized_config["dataset_params"]["config"]["synthetic_data"][
+                "proba_scheduler_function"
+            ]
+        )
+
     return serialized_config
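
Note: the str() cast is what makes the scheduler entry JSON-serializable. A tiny illustration with a made-up scheduler function:

import json

def linear_scheduler(step):  # hypothetical probability scheduler
    return min(1.0, step / 1000)

# json.dumps raises a TypeError on the raw function object, while its
# string form serializes without issue.
print(json.dumps({"proba_scheduler_function": str(linear_scheduler)}))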
@@ -281,24 +305,32 @@ def run():
     config, dataset_name = get_config()

     if MLFLOW and "mlflow" in config:
-        config_artifact = serialize_config(config)
         labels_path = (
-            Path(config_artifact["dataset_params"]["datasets"][dataset_name])
-            / "labels.json"
+            Path(config["dataset_params"]["datasets"][dataset_name]) / "labels.json"
         )

-        with start_mlflow_run(config["mlflow"]) as run:
-            logger.info(f"Set tags to MLflow on {config['mlflow']['run_name']}")
-            mlflow.set_tags({"Dataset": config["mlflow"]["dataset_name"]})
+        with start_mlflow_run(config["mlflow"]) as (run, created):
+            if created:
+                logger.info(f"Started MLflow run with ID ({run.info.run_id})")
+            else:
+                logger.info(f"Resumed MLflow run with ID ({run.info.run_id})")
+
+            make_mlflow_request(
+                mlflow_method=mlflow.set_tags, tags={"Dataset": dataset_name}
+            )

             # Get the labels json file
             with open(labels_path) as json_file:
                 labels_artifact = json.load(json_file)

             # Log MLflow artifacts
-            mlflow.log_dict(config_artifact, "config.json")
-            mlflow.log_dict(labels_artifact, "labels.json")
-            logger.info(f"Started MLflow run with ID ({run.info.run_id})")
+            for artifact, filename in [
+                (serialize_config(config), "config.json"),
+                (labels_artifact, "labels.json"),
+            ]:
+                make_mlflow_request(
+                    mlflow_method=mlflow.log_dict,
+                    dictionary=artifact,
+                    artifact_file=filename,
+                )

     if (
         config["training_params"]["use_ddp"]
         and not config["training_params"]["force_cpu"]
......
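
Note: with these changes, resuming an earlier run only requires filling in run_id in the mlflow section of the configuration; a sketch (the run ID is a placeholder):

config, dataset_name = get_config()
config["mlflow"]["run_id"] = "0123456789abcdef0123456789abcdef"  # placeholder ID
# run_name is then ignored and start_mlflow_run resumes the existing run,
# so created comes back False inside run().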