diff --git a/dan/ocr/document/train.py b/dan/ocr/document/train.py index 4bf8640a76969dcf9e65ae24dfd1053511bcdc87..eaea50168a53bc6a37a04c2f38ec29f2a24dcc57 100644 --- a/dan/ocr/document/train.py +++ b/dan/ocr/document/train.py @@ -277,6 +277,19 @@ def serialize_config(config): serialized_config["training_params"]["nb_gpu"] = str( serialized_config["training_params"]["nb_gpu"] ) + + if ( + "synthetic_data" in config["dataset_params"]["config"] + and config["dataset_params"]["config"]["synthetic_data"] + ): + serialized_config["dataset_params"]["config"]["synthetic_data"][ + "proba_scheduler_function" + ] = str( + serialized_config["dataset_params"]["config"]["synthetic_data"][ + "proba_scheduler_function" + ] + ) + return serialized_config @@ -288,10 +301,8 @@ def run(): config, dataset_name = get_config() if MLFLOW and "mlflow" in config: - config_artifact = serialize_config(config) labels_path = ( - Path(config_artifact["dataset_params"]["datasets"][dataset_name]) - / "labels.json" + Path(config["dataset_params"]["datasets"][dataset_name]) / "labels.json" ) with start_mlflow_run(config["mlflow"]) as run: logger.info(f"Started MLflow run with ID ({run.info.run_id})") @@ -305,7 +316,7 @@ def run(): # Log MLflow artifacts for artifact, filename in [ - (config_artifact, "config.json"), + (serialize_config(config), "config.json"), (labels_artifact, "labels.json"), ]: make_mlflow_request(