From 04ca87e23e916ca7186a4429f44137e97d9b098e Mon Sep 17 00:00:00 2001
From: Yoann Schneider <yschneider@teklia.com>
Date: Fri, 10 Feb 2023 15:10:30 +0100
Subject: [PATCH] Fix publication of the serialized config as an MLflow artifact

---
 dan/ocr/document/train.py | 19 +++++++++++++++----
 1 file changed, 15 insertions(+), 4 deletions(-)

diff --git a/dan/ocr/document/train.py b/dan/ocr/document/train.py
index 4bf8640a..eaea5016 100644
--- a/dan/ocr/document/train.py
+++ b/dan/ocr/document/train.py
@@ -277,6 +277,19 @@ def serialize_config(config):
     serialized_config["training_params"]["nb_gpu"] = str(
         serialized_config["training_params"]["nb_gpu"]
     )
+
+    if (
+        "synthetic_data" in config["dataset_params"]["config"]
+        and config["dataset_params"]["config"]["synthetic_data"]
+    ):
+        serialized_config["dataset_params"]["config"]["synthetic_data"][
+            "proba_scheduler_function"
+        ] = str(
+            serialized_config["dataset_params"]["config"]["synthetic_data"][
+                "proba_scheduler_function"
+            ]
+        )
+
     return serialized_config
 
 
@@ -288,10 +301,8 @@ def run():
     config, dataset_name = get_config()
 
     if MLFLOW and "mlflow" in config:
-        config_artifact = serialize_config(config)
         labels_path = (
-            Path(config_artifact["dataset_params"]["datasets"][dataset_name])
-            / "labels.json"
+            Path(config["dataset_params"]["datasets"][dataset_name]) / "labels.json"
         )
         with start_mlflow_run(config["mlflow"]) as run:
             logger.info(f"Started MLflow run with ID ({run.info.run_id})")            
@@ -305,7 +316,7 @@ def run():
 
             # Log MLflow artifacts
             for artifact, filename in [
-                (config_artifact, "config.json"),
+                (serialize_config(config), "config.json"),
                 (labels_artifact, "labels.json"),
             ]:
                 make_mlflow_request(
-- 
GitLab