diff --git a/dan/ocr/document/train.py b/dan/ocr/document/train.py index 04eb24ad38dc9cd9db9674929ae9490b4f27e2d7..63a0051778baaa74e70ed66d9c5f0013f40a908e 100644 --- a/dan/ocr/document/train.py +++ b/dan/ocr/document/train.py @@ -270,6 +270,19 @@ def serialize_config(config): serialized_config["training_params"]["nb_gpu"] = str( serialized_config["training_params"]["nb_gpu"] ) + + if ( + "synthetic_data" in config["dataset_params"]["config"] + and config["dataset_params"]["config"]["synthetic_data"] + ): + serialized_config["dataset_params"]["config"]["synthetic_data"][ + "proba_scheduler_function" + ] = str( + serialized_config["dataset_params"]["config"]["synthetic_data"][ + "proba_scheduler_function" + ] + ) + return serialized_config @@ -281,10 +294,8 @@ def run(): config, dataset_name = get_config() if MLFLOW and "mlflow" in config: - config_artifact = serialize_config(config) labels_path = ( - Path(config_artifact["dataset_params"]["datasets"][dataset_name]) - / "labels.json" + Path(config["dataset_params"]["datasets"][dataset_name]) / "labels.json" ) with start_mlflow_run(config["mlflow"]) as run: logger.info(f"Started MLflow run with ID ({run.info.run_id})") @@ -299,7 +310,7 @@ def run(): # Log MLflow artifacts for artifact, filename in [ - (config_artifact, "config.json"), + (serialize_config(config), "config.json"), (labels_artifact, "labels.json"), ]: make_mlflow_request(