From 04ca87e23e916ca7186a4429f44137e97d9b098e Mon Sep 17 00:00:00 2001 From: Yoann Schneider <yschneider@teklia.com> Date: Fri, 10 Feb 2023 15:10:30 +0100 Subject: [PATCH] patch to fix config publication --- dan/ocr/document/train.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/dan/ocr/document/train.py b/dan/ocr/document/train.py index 4bf8640a..eaea5016 100644 --- a/dan/ocr/document/train.py +++ b/dan/ocr/document/train.py @@ -277,6 +277,19 @@ def serialize_config(config): serialized_config["training_params"]["nb_gpu"] = str( serialized_config["training_params"]["nb_gpu"] ) + + if ( + "synthetic_data" in config["dataset_params"]["config"] + and config["dataset_params"]["config"]["synthetic_data"] + ): + serialized_config["dataset_params"]["config"]["synthetic_data"][ + "proba_scheduler_function" + ] = str( + serialized_config["dataset_params"]["config"]["synthetic_data"][ + "proba_scheduler_function" + ] + ) + return serialized_config @@ -288,10 +301,8 @@ def run(): config, dataset_name = get_config() if MLFLOW and "mlflow" in config: - config_artifact = serialize_config(config) labels_path = ( - Path(config_artifact["dataset_params"]["datasets"][dataset_name]) - / "labels.json" + Path(config["dataset_params"]["datasets"][dataset_name]) / "labels.json" ) with start_mlflow_run(config["mlflow"]) as run: logger.info(f"Started MLflow run with ID ({run.info.run_id})") @@ -305,7 +316,7 @@ def run(): # Log MLflow artifacts for artifact, filename in [ - (config_artifact, "config.json"), + (serialize_config(config), "config.json"), (labels_artifact, "labels.json"), ]: make_mlflow_request( -- GitLab