Skip to content
Snippets Groups Projects
Verified Commit 82d30655 authored by Yoann Schneider's avatar Yoann Schneider :tennis:
Browse files

do not force mlflow

parent ac1f954c
No related branches found
No related tags found
2 merge requests!59Robust mlflow requests,!57Do not force mlflow usage
...@@ -8,6 +8,12 @@ from mlflow.exceptions import MlflowException ...@@ -8,6 +8,12 @@ from mlflow.exceptions import MlflowException
from dan import logger from dan import logger
class MLflowNotInstalled(Exception):
"""
Raised when MLflow logging was requested but the module was not installed
"""
def setup_environment(config: dict): def setup_environment(config: dict):
""" """
Get the necessary variables from the config file and put them in the environment variables Get the necessary variables from the config file and put them in the environment variables
......
...@@ -5,20 +5,28 @@ import random ...@@ -5,20 +5,28 @@ import random
from copy import deepcopy from copy import deepcopy
from pathlib import Path from pathlib import Path
import mlflow
import numpy as np import numpy as np
import torch import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
from torch.optim import Adam from torch.optim import Adam
from dan import logger
from dan.decoder import GlobalHTADecoder from dan.decoder import GlobalHTADecoder
from dan.manager.ocr import OCRDataset, OCRDatasetManager from dan.manager.ocr import OCRDataset, OCRDatasetManager
from dan.manager.training import Manager from dan.manager.training import Manager
from dan.mlflow import start_mlflow_run from dan.mlflow import MLflowNotInstalled, start_mlflow_run
from dan.models import FCN_Encoder from dan.models import FCN_Encoder
from dan.schedulers import exponential_dropout_scheduler from dan.schedulers import exponential_dropout_scheduler
from dan.transforms import aug_config from dan.transforms import aug_config
try:
import mlflow
MLFLOW = True
logger.info("MLflow Logging available.")
except ImportError:
MLFLOW = False
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -272,13 +280,14 @@ def run(): ...@@ -272,13 +280,14 @@ def run():
config = get_config() config = get_config()
config_artifact = serialize_config(config) config_artifact = serialize_config(config)
labels_artifact = "" labels_artifact = ""
dataset_name = config["mlflow"]["dataset_name"] # The only key of this dict is the name of the dataset
dataset_name = config_artifact["dataset_params"]["datasets"].keys()[0]
labels_path = ( labels_path = (
Path(config_artifact["dataset_params"]["datasets"][dataset_name]) Path(config_artifact["dataset_params"]["datasets"][dataset_name])
/ "labels.json" / "labels.json"
) )
if config["mlflow"]: if MLFLOW and config["mlflow"]:
with start_mlflow_run(config["mlflow"]) as run: with start_mlflow_run(config["mlflow"]) as run:
logger.info(f"Set tags to MLflow on {config['mlflow']['run_name']}") logger.info(f"Set tags to MLflow on {config['mlflow']['run_name']}")
mlflow.set_tags({"Dataset": config["mlflow"]["dataset_name"]}) mlflow.set_tags({"Dataset": config["mlflow"]["dataset_name"]})
...@@ -303,6 +312,11 @@ def run(): ...@@ -303,6 +312,11 @@ def run():
) )
else: else:
train_and_test(0, config, True) train_and_test(0, config, True)
elif config["mlflow"]:
logger.error(
"Cannot log to MLflow as the `mlflow` module was not found in your environment."
)
raise MLflowNotInstalled()
else: else:
if ( if (
config["training_params"]["use_ddp"] config["training_params"]["use_ddp"]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment