diff --git a/dan/mlflow.py b/dan/mlflow.py index 06b36fc8c502eb09fa36c5dc3506b9631171c285..b89342ed122182518c1fa0068ae45da52983fb71 100644 --- a/dan/mlflow.py +++ b/dan/mlflow.py @@ -91,7 +91,9 @@ def logging_tags_metrics( @contextmanager def start_mlflow_run(config: dict): """ - Create an MLflow execution context with the parameters contained in the config file + Create an MLflow execution context with the parameters contained in the config file. + + Yields the active MLflow run, as well as a boolean saying whether a new one was created. :param config: dict, the config of the model """ @@ -99,10 +101,22 @@ def start_mlflow_run(config: dict): # Set needed variables in environment setup_environment(config) + run_name, run_id = config.get("run_name"), config.get("run_id") + + if run_id: + logger.info(f"Will resume run ({run_id}).") + + if run_name: + logger.warning( + "Run_name will be ignored since you specified a run_id to resume from." + ) + # Set experiment from config experiment_id = config.get("experiment_id") assert experiment_id, "Missing MLflow experiment ID in the configuration" # Start run - yield mlflow.start_run(run_name=config.get("run_name"), experiment_id=experiment_id) + yield mlflow.start_run( + run_id=run_id, run_name=run_name, experiment_id=experiment_id + ), run_id is None mlflow.end_run() diff --git a/dan/ocr/document/train.py b/dan/ocr/document/train.py index 4b005ec4bdbdd5512c7c2d37d825df573cb3f6e1..9552d042bc92a158cb29ddf9bc5314eb97d2b7f0 100644 --- a/dan/ocr/document/train.py +++ b/dan/ocr/document/train.py @@ -24,10 +24,12 @@ try: MLFLOW = True logger.info("MLflow Logging available.") + from dan.mlflow import make_mlflow_request, start_mlflow_run except ImportError: MLFLOW = False + logger = logging.getLogger(__name__) @@ -76,8 +78,8 @@ def get_config(): dataset_path = "." params = { "mlflow": { - "dataset_name": dataset_name, "run_name": "Test log DAN", + "run_id": None, "s3_endpoint_url": "", "tracking_uri": "", "experiment_id": "0", @@ -237,7 +239,10 @@ def get_config(): def serialize_config(config): """ - Serialize a dictionary to transform it into json and remove the credentials + Make every field of the configuration JSON-Serializable and remove sensitive information. + + - Classes are transformed using their name attribute + - Functions are casted to strings """ # Create a copy of the original config without erase it serialized_config = deepcopy(config) @@ -275,6 +280,20 @@ def serialize_config(config): serialized_config["training_params"]["nb_gpu"] = str( serialized_config["training_params"]["nb_gpu"] ) + + if ( + "synthetic_data" in config["dataset_params"]["config"] + and config["dataset_params"]["config"]["synthetic_data"] + ): + # The Probability scheduler is a function and needs to be casted to string + serialized_config["dataset_params"]["config"]["synthetic_data"][ + "proba_scheduler_function" + ] = str( + serialized_config["dataset_params"]["config"]["synthetic_data"][ + "proba_scheduler_function" + ] + ) + return serialized_config @@ -286,25 +305,25 @@ def run(): config, dataset_name = get_config() if MLFLOW and "mlflow" in config: - config_artifact = serialize_config(config) labels_path = ( - Path(config_artifact["dataset_params"]["datasets"][dataset_name]) - / "labels.json" + Path(config["dataset_params"]["datasets"][dataset_name]) / "labels.json" ) - with start_mlflow_run(config["mlflow"]) as run: - logger.info(f"Started MLflow run with ID ({run.info.run_id})") + with start_mlflow_run(config["mlflow"]) as (run, created): + if created: + logger.info(f"Started MLflow run with ID ({run.info.run_id})") + else: + logger.info(f"Resumed MLflow run with ID ({run.info.run_id})") make_mlflow_request( mlflow_method=mlflow.set_tags, tags={"Dataset": dataset_name} ) - # Get the labels json file with open(labels_path) as json_file: labels_artifact = json.load(json_file) # Log MLflow artifacts for artifact, filename in [ - (config_artifact, "config.json"), + (serialize_config(config), "config.json"), (labels_artifact, "labels.json"), ]: make_mlflow_request( @@ -312,7 +331,6 @@ def run(): dictionary=artifact, artifact_file=filename, ) - if ( config["training_params"]["use_ddp"] and not config["training_params"]["force_cpu"]