diff --git a/dan/manager/training.py b/dan/manager/training.py index d7b4897c50645c5254741d1765bab35d13bdea81..7b565c41e9465531821f8af5ed739d1794670559 100644 --- a/dan/manager/training.py +++ b/dan/manager/training.py @@ -23,8 +23,10 @@ from tqdm import tqdm from dan.manager.metrics import MetricManager from dan.ocr.utils import LM_ind_to_str from dan.schedulers import DropoutScheduler + try: import mlflow + from dan.mlflow import logging_metrics, logging_tags_metrics except ImportError: pass diff --git a/dan/mlflow.py b/dan/mlflow.py index 20b3436e13076580ca4c4cfc82901814928ae741..063c2b91d96e81f62914b0d14fa14596473b481c 100644 --- a/dan/mlflow.py +++ b/dan/mlflow.py @@ -8,7 +8,6 @@ from mlflow.exceptions import MlflowException from dan import logger - def setup_environment(config: dict): """ Get the necessary variables from the config file and put them in the environment variables diff --git a/dan/utils.py b/dan/utils.py index ad8d3ca4c881bed270ea687486585fa495387a3b..50f7311d602c97e80c308637e1b8cb37d8c90f95 100644 --- a/dan/utils.py +++ b/dan/utils.py @@ -23,6 +23,7 @@ class MLflowNotInstalled(Exception): Raised when MLflow logging was requested but the module was not installed """ + def randint(low, high): """ call torch.randint to preserve random among dataloader workers