From 82d306552f76dd0ca947652bc458ee9bd17cca41 Mon Sep 17 00:00:00 2001 From: Yoann Schneider <yschneider@teklia.com> Date: Fri, 10 Feb 2023 11:05:10 +0100 Subject: [PATCH] do not force mlflow --- dan/mlflow.py | 6 ++++++ dan/ocr/document/train.py | 22 ++++++++++++++++++---- 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/dan/mlflow.py b/dan/mlflow.py index 063c2b91..a0fba3cc 100644 --- a/dan/mlflow.py +++ b/dan/mlflow.py @@ -8,6 +8,12 @@ from mlflow.exceptions import MlflowException from dan import logger +class MLflowNotInstalled(Exception): + """ + Raised when MLflow logging was requested but the module was not installed + """ + + def setup_environment(config: dict): """ Get the necessary variables from the config file and put them in the environment variables diff --git a/dan/ocr/document/train.py b/dan/ocr/document/train.py index a0366aac..eff73958 100644 --- a/dan/ocr/document/train.py +++ b/dan/ocr/document/train.py @@ -5,20 +5,28 @@ import random from copy import deepcopy from pathlib import Path -import mlflow import numpy as np import torch import torch.multiprocessing as mp from torch.optim import Adam +from dan import logger from dan.decoder import GlobalHTADecoder from dan.manager.ocr import OCRDataset, OCRDatasetManager from dan.manager.training import Manager -from dan.mlflow import start_mlflow_run +from dan.mlflow import MLflowNotInstalled, start_mlflow_run from dan.models import FCN_Encoder from dan.schedulers import exponential_dropout_scheduler from dan.transforms import aug_config +try: + import mlflow + + MLFLOW = True + logger.info("MLflow Logging available.") +except ImportError: + MLFLOW = False + logger = logging.getLogger(__name__) @@ -272,13 +280,14 @@ def run(): config = get_config() config_artifact = serialize_config(config) labels_artifact = "" - dataset_name = config["mlflow"]["dataset_name"] + # The only key of this dict is the name of the dataset + dataset_name = config_artifact["dataset_params"]["datasets"].keys()[0] labels_path = ( Path(config_artifact["dataset_params"]["datasets"][dataset_name]) / "labels.json" ) - if config["mlflow"]: + if MLFLOW and config["mlflow"]: with start_mlflow_run(config["mlflow"]) as run: logger.info(f"Set tags to MLflow on {config['mlflow']['run_name']}") mlflow.set_tags({"Dataset": config["mlflow"]["dataset_name"]}) @@ -303,6 +312,11 @@ def run(): ) else: train_and_test(0, config, True) + elif config["mlflow"]: + logger.error( + "Cannot log to MLflow as the `mlflow` module was not found in your environment." + ) + raise MLflowNotInstalled() else: if ( config["training_params"]["use_ddp"] -- GitLab