diff --git a/dan/mlflow.py b/dan/mlflow.py index 9fdd4a6512ef198946cfb5fa438397a37f6795f2..55efde338a001d7efc3d1fe0608e971a8d900805 100644 --- a/dan/mlflow.py +++ b/dan/mlflow.py @@ -91,10 +91,22 @@ def start_mlflow_run(config: dict): # Set needed variables in environment setup_environment(config) + run_name, run_id = config.get("run_name"), config.get("run_id") + + if run_id: + logger.info(f"Will resume run ({run_id}).") + + if run_name: + logger.warning( + "Run_name will be ignored since you specified a run_id to resume from." + ) + # Set experiment from config experiment_id = config.get("experiment_id") assert experiment_id, "Missing MLflow experiment ID in the configuration" # Start run - yield mlflow.start_run(run_name=config.get("run_name"), experiment_id=experiment_id) + yield mlflow.start_run( + run_id=run_id, run_name=run_name, experiment_id=experiment_id + ) mlflow.end_run() diff --git a/dan/ocr/document/train.py b/dan/ocr/document/train.py index 72abb8b8426cdd2f65ef7cefbfc21fe7afc83e92..04eb24ad38dc9cd9db9674929ae9490b4f27e2d7 100644 --- a/dan/ocr/document/train.py +++ b/dan/ocr/document/train.py @@ -76,8 +76,8 @@ def get_config(): dataset_path = "." params = { "mlflow": { - "dataset_name": dataset_name, "run_name": "Test log DAN", + "run_id": None, "s3_endpoint_url": "", "tracking_uri": "", "experiment_id": "0",