Skip to content
Snippets Groups Projects
Commit 5a7dd7ef authored by Yoann Schneider's avatar Yoann Schneider :tennis:
Browse files

Correctly set mlflow logging

parent a5f7b5ab
No related branches found
No related tags found
1 merge request!185Correctly set mlflow logging
...@@ -258,18 +258,18 @@ def serialize_config(config): ...@@ -258,18 +258,18 @@ def serialize_config(config):
return serialized_config return serialized_config
def start_training(config) -> None: def start_training(config, mlflow_logging: bool) -> None:
if ( if (
config["training_params"]["use_ddp"] config["training_params"]["use_ddp"]
and not config["training_params"]["force_cpu"] and not config["training_params"]["force_cpu"]
): ):
mp.spawn( mp.spawn(
train_and_test, train_and_test,
args=(config, True), args=(config, mlflow_logging),
nprocs=config["training_params"]["nb_gpu"], nprocs=config["training_params"]["nb_gpu"],
) )
else: else:
train_and_test(0, config, True) train_and_test(0, config, mlflow_logging)
def run(): def run():
...@@ -286,7 +286,7 @@ def run(): ...@@ -286,7 +286,7 @@ def run():
raise MLflowNotInstalled() raise MLflowNotInstalled()
if "mlflow" not in config: if "mlflow" not in config:
start_training(config) start_training(config, mlflow_logging=False)
else: else:
labels_path = ( labels_path = (
Path(config["dataset_params"]["datasets"][dataset_name]) / "labels.json" Path(config["dataset_params"]["datasets"][dataset_name]) / "labels.json"
...@@ -314,4 +314,4 @@ def run(): ...@@ -314,4 +314,4 @@ def run():
dictionary=artifact, dictionary=artifact,
artifact_file=filename, artifact_file=filename,
) )
start_training(config) start_training(config, mlflow_logging=True)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment