diff --git a/dan/manager/training.py b/dan/manager/training.py index a1c3b1e43f57c7e22da6680cd162f2550ac7d79b..7b565c41e9465531821f8af5ed739d1794670559 100644 --- a/dan/manager/training.py +++ b/dan/manager/training.py @@ -31,13 +31,6 @@ try: except ImportError: pass -try: - import mlflow - - from dan.mlflow import logging_metrics, logging_tags_metrics -except ImportError: - pass - class GenericTrainingManager: def __init__(self, params): diff --git a/dan/mlflow.py b/dan/mlflow.py index 95b41bf8eacb2ae8efb4fd865e93f66abf1e39bf..b89342ed122182518c1fa0068ae45da52983fb71 100644 --- a/dan/mlflow.py +++ b/dan/mlflow.py @@ -91,7 +91,9 @@ def logging_tags_metrics( @contextmanager def start_mlflow_run(config: dict): """ - Create an MLflow execution context with the parameters contained in the config file + Create an MLflow execution context with the parameters contained in the config file. + + Yields the active MLflow run, as well as a boolean saying whether a new one was created. :param config: dict, the config of the model """ @@ -116,5 +118,5 @@ def start_mlflow_run(config: dict): # Start run yield mlflow.start_run( run_id=run_id, run_name=run_name, experiment_id=experiment_id - ) + ), run_id is None mlflow.end_run() diff --git a/dan/ocr/document/train.py b/dan/ocr/document/train.py index eaea50168a53bc6a37a04c2f38ec29f2a24dcc57..3eb15552ba975bbb294a00fd625a21ad7f1e4592 100644 --- a/dan/ocr/document/train.py +++ b/dan/ocr/document/train.py @@ -304,11 +304,14 @@ def run(): labels_path = ( Path(config["dataset_params"]["datasets"][dataset_name]) / "labels.json" ) - with start_mlflow_run(config["mlflow"]) as run: - logger.info(f"Started MLflow run with ID ({run.info.run_id})") + with start_mlflow_run(config["mlflow"]) as (run, created): + if created: + logger.info(f"Started MLflow run with ID ({run.info.run_id})") + else: + logger.info(f"Resumed MLflow run with ID ({run.info.run_id})") + make_mlflow_request( - mlflow_method=mlflow.set_tags, - tags={"Dataset": dataset_name} + mlflow_method=mlflow.set_tags, tags={"Dataset": dataset_name} ) # Get the labels json file with open(labels_path) as json_file: