Resume from existing mlflow run

Merged: Yoann Schneider requested to merge resume-from-existing-mlflow-run into main
All threads resolved! 2 files changed (+44, −12); the file shown below accounts for +28, −10.

```diff
@@ -24,10 +24,12 @@ try:
     MLFLOW = True
     logger.info("MLflow Logging available.")
     from dan.mlflow import make_mlflow_request, start_mlflow_run
 except ImportError:
     MLFLOW = False

 logger = logging.getLogger(__name__)
```

```diff
@@ -76,8 +78,8 @@ def get_config():
     dataset_path = "."
     params = {
         "mlflow": {
             "dataset_name": dataset_name,
             "run_name": "Test log DAN",
+            "run_id": None,
             "s3_endpoint_url": "",
             "tracking_uri": "",
             "experiment_id": "0",
```

```diff
@@ -237,7 +239,10 @@ def get_config():
 def serialize_config(config):
     """
-    Serialize a dictionary to transform it into json and remove the credentials
+    Make every field of the configuration JSON-Serializable and remove sensitive information.
+
+    - Classes are transformed using their name attribute
+    - Functions are casted to strings
     """
     # Create a copy of the original config without erase it
     serialized_config = deepcopy(config)
```

```diff
@@ -275,6 +280,20 @@ def serialize_config(config):
     serialized_config["training_params"]["nb_gpu"] = str(
         serialized_config["training_params"]["nb_gpu"]
     )
+
+    if (
+        "synthetic_data" in config["dataset_params"]["config"]
+        and config["dataset_params"]["config"]["synthetic_data"]
+    ):
+        # The Probability scheduler is a function and needs to be casted to string
+        serialized_config["dataset_params"]["config"]["synthetic_data"][
+            "proba_scheduler_function"
+        ] = str(
+            serialized_config["dataset_params"]["config"]["synthetic_data"][
+                "proba_scheduler_function"
+            ]
+        )
+
     return serialized_config
```

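The cast is needed because `json.dumps` cannot handle function objects, and the scheduler stored under `proba_scheduler_function` is a plain Python function. A minimal illustration, using a hypothetical `linear_scheduler` that stands in for the real probability scheduler:

```python
import json
from copy import deepcopy


def linear_scheduler(step: int) -> float:
    """Hypothetical probability scheduler, for illustration only."""
    return min(1.0, step / 10_000)


config = {
    "dataset_params": {
        "config": {
            "synthetic_data": {"proba_scheduler_function": linear_scheduler}
        }
    }
}

serialized = deepcopy(config)
synthetic = serialized["dataset_params"]["config"]["synthetic_data"]
# Without the cast, json.dumps(config) raises:
# TypeError: Object of type function is not JSON serializable
synthetic["proba_scheduler_function"] = str(synthetic["proba_scheduler_function"])
print(json.dumps(serialized, indent=2))
```
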
```diff
@@ -286,25 +305,25 @@ def run():
     config, dataset_name = get_config()

     if MLFLOW and "mlflow" in config:
-        config_artifact = serialize_config(config)
         labels_path = (
-            Path(config_artifact["dataset_params"]["datasets"][dataset_name])
-            / "labels.json"
+            Path(config["dataset_params"]["datasets"][dataset_name]) / "labels.json"
         )
-        with start_mlflow_run(config["mlflow"]) as run:
-            logger.info(f"Started MLflow run with ID ({run.info.run_id})")
+        with start_mlflow_run(config["mlflow"]) as (run, created):
+            if created:
+                logger.info(f"Started MLflow run with ID ({run.info.run_id})")
+            else:
+                logger.info(f"Resumed MLflow run with ID ({run.info.run_id})")
             make_mlflow_request(
                 mlflow_method=mlflow.set_tags, tags={"Dataset": dataset_name}
             )
             # Get the labels json file
             with open(labels_path) as json_file:
                 labels_artifact = json.load(json_file)

             # Log MLflow artifacts
             for artifact, filename in [
-                (config_artifact, "config.json"),
+                (serialize_config(config), "config.json"),
                 (labels_artifact, "labels.json"),
             ]:
                 make_mlflow_request(
```

```diff
@@ -312,7 +331,6 @@ def run():
                     dictionary=artifact,
                     artifact_file=filename,
                 )

     if (
         config["training_params"]["use_ddp"]
         and not config["training_params"]["force_cpu"]
```
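
`make_mlflow_request` is likewise outside this diff. Judging by its call sites (`mlflow.set_tags` above, and presumably `mlflow.log_dict`, whose parameters are `dictionary` and `artifact_file`), it forwards keyword arguments to an arbitrary MLflow method while shielding training from tracking-server failures. A plausible sketch, assuming that behaviour rather than the actual `dan.mlflow` code:

```python
import logging

import requests

logger = logging.getLogger(__name__)


def make_mlflow_request(mlflow_method, **kwargs):
    # Sketch: forward to the given MLflow method, but log and swallow
    # connection errors so a tracking-server outage does not stop training.
    try:
        mlflow_method(**kwargs)
    except requests.exceptions.ConnectionError as e:
        logger.error(f"Request to MLflow failed: {e}")
```

Under this assumption, the artifact loop above logs `config.json` and `labels.json` via `mlflow.log_dict`, and a transient connection error is logged and skipped rather than aborting the whole training run.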