diff --git a/dan/manager/training.py b/dan/manager/training.py index 4c887a5c956b8d60c878e535ccce8f2e147a9e3a..d7b4897c50645c5254741d1765bab35d13bdea81 100644 --- a/dan/manager/training.py +++ b/dan/manager/training.py @@ -8,7 +8,6 @@ import sys from datetime import date from time import time -import mlflow import numpy as np import torch import torch.distributed as dist @@ -22,9 +21,13 @@ from torch.utils.tensorboard import SummaryWriter from tqdm import tqdm from dan.manager.metrics import MetricManager -from dan.mlflow import logging_metrics, logging_tags_metrics from dan.ocr.utils import LM_ind_to_str from dan.schedulers import DropoutScheduler +try: + import mlflow + from dan.mlflow import logging_metrics, logging_tags_metrics +except ImportError: + pass class GenericTrainingManager: diff --git a/dan/mlflow.py b/dan/mlflow.py index a0fba3ccb14f2e3fe8a84b19d5d13c9356394198..20b3436e13076580ca4c4cfc82901814928ae741 100644 --- a/dan/mlflow.py +++ b/dan/mlflow.py @@ -8,11 +8,6 @@ from mlflow.exceptions import MlflowException from dan import logger -class MLflowNotInstalled(Exception): - """ - Raised when MLflow logging was requested but the module was not installed - """ - def setup_environment(config: dict): """ diff --git a/dan/ocr/document/train.py b/dan/ocr/document/train.py index eff73958ad3cf3180e241b179e1147b71c83a1da..6a392dc2de3f476d5ac8d07b880345498775cea8 100644 --- a/dan/ocr/document/train.py +++ b/dan/ocr/document/train.py @@ -14,16 +14,17 @@ from dan import logger from dan.decoder import GlobalHTADecoder from dan.manager.ocr import OCRDataset, OCRDatasetManager from dan.manager.training import Manager -from dan.mlflow import MLflowNotInstalled, start_mlflow_run from dan.models import FCN_Encoder from dan.schedulers import exponential_dropout_scheduler from dan.transforms import aug_config +from dan.utils import MLflowNotInstalled try: import mlflow MLFLOW = True logger.info("MLflow Logging available.") + from dan.mlflow import start_mlflow_run except 
ImportError: MLFLOW = False @@ -70,9 +71,9 @@ def get_config(): Retrieve model configuration """ dataset_name = "esposalles" - dataset_level = "record" - dataset_variant = "_debug" - dataset_path = "" + dataset_level = "page" + dataset_variant = "" + dataset_path = "/home/training_data/ATR_paragraph/Esposalles" params = { "mlflow": { "dataset_name": dataset_name, @@ -226,7 +227,7 @@ def get_config(): }, } - return params + return params, dataset_name def serialize_config(config): @@ -277,17 +278,14 @@ def run(): Main program, training a new model, using a valid configuration """ - config = get_config() - config_artifact = serialize_config(config) - labels_artifact = "" - # The only key of this dict is the name of the dataset - dataset_name = config_artifact["dataset_params"]["datasets"].keys()[0] - labels_path = ( - Path(config_artifact["dataset_params"]["datasets"][dataset_name]) - / "labels.json" - ) + config, dataset_name = get_config() - if MLFLOW and config["mlflow"]: + if MLFLOW and "mlflow" in config: + config_artifact = serialize_config(config) + labels_path = ( + Path(config_artifact["dataset_params"]["datasets"][dataset_name]) + / "labels.json" + ) with start_mlflow_run(config["mlflow"]) as run: logger.info(f"Set tags to MLflow on {config['mlflow']['run_name']}") mlflow.set_tags({"Dataset": config["mlflow"]["dataset_name"]}) @@ -312,7 +310,7 @@ def run(): ) else: train_and_test(0, config, True) - elif config["mlflow"]: + elif "mlflow" in config: logger.error( "Cannot log to MLflow as the `mlflow` module was not found in your environment." 
) diff --git a/dan/utils.py b/dan/utils.py index 8d6f4b4f4407b18af58ba09265b5d6639ccbf858..ad8d3ca4c881bed270ea687486585fa495387a3b 100644 --- a/dan/utils.py +++ b/dan/utils.py @@ -18,6 +18,11 @@ SEM_MATCHING_TOKENS_STR = { SEM_MATCHING_TOKENS = {"ⓘ": "Ⓘ", "ⓓ": "Ⓓ", "ⓢ": "Ⓢ", "ⓒ": "Ⓒ", "ⓟ": "Ⓟ", "ⓐ": "Ⓐ"} +class MLflowNotInstalled(Exception): +    """ +    Raised when MLflow logging was requested but the module was not installed +    """ + def randint(low, high): """ call torch.randint to preserve random among dataloader workers