diff --git a/dan/mlflow.py b/dan/mlflow.py index 063c2b91d96e81f62914b0d14fa14596473b481c..a0fba3ccb14f2e3fe8a84b19d5d13c9356394198 100644 --- a/dan/mlflow.py +++ b/dan/mlflow.py @@ -8,6 +8,12 @@ from mlflow.exceptions import MlflowException from dan import logger +class MLflowNotInstalled(Exception): + """ + Raised when MLflow logging was requested but the module was not installed + """ + + def setup_environment(config: dict): """ Get the necessary variables from the config file and put them in the environment variables diff --git a/dan/ocr/document/train.py b/dan/ocr/document/train.py index a0366aac67e5def6c8da13b2d1598f96546e0deb..eff73958ad3cf3180e241b179e1147b71c83a1da 100644 --- a/dan/ocr/document/train.py +++ b/dan/ocr/document/train.py @@ -5,20 +5,28 @@ import random from copy import deepcopy from pathlib import Path -import mlflow import numpy as np import torch import torch.multiprocessing as mp from torch.optim import Adam +from dan import logger from dan.decoder import GlobalHTADecoder from dan.manager.ocr import OCRDataset, OCRDatasetManager from dan.manager.training import Manager -from dan.mlflow import start_mlflow_run +from dan.mlflow import MLflowNotInstalled, start_mlflow_run from dan.models import FCN_Encoder from dan.schedulers import exponential_dropout_scheduler from dan.transforms import aug_config +try: + import mlflow + + MLFLOW = True + logger.info("MLflow Logging available.") +except ImportError: + MLFLOW = False + logger = logging.getLogger(__name__) @@ -272,13 +280,14 @@ def run(): config = get_config() config_artifact = serialize_config(config) labels_artifact = "" - dataset_name = config["mlflow"]["dataset_name"] + # The only key of this dict is the name of the dataset + dataset_name = config_artifact["dataset_params"]["datasets"].keys()[0] labels_path = ( Path(config_artifact["dataset_params"]["datasets"][dataset_name]) / "labels.json" ) - if config["mlflow"]: + if MLFLOW and config["mlflow"]: with start_mlflow_run(config["mlflow"]) as run: logger.info(f"Set tags to MLflow on {config['mlflow']['run_name']}") mlflow.set_tags({"Dataset": config["mlflow"]["dataset_name"]}) @@ -303,6 +312,11 @@ def run(): ) else: train_and_test(0, config, True) + elif config["mlflow"]: + logger.error( + "Cannot log to MLflow as the `mlflow` module was not found in your environment." + ) + raise MLflowNotInstalled() else: if ( config["training_params"]["use_ddp"]