Skip to content
Snippets Groups Projects
Commit bf9cf447 authored by Yoann Schneider's avatar Yoann Schneider :tennis:
Browse files

check if mlflow is available at runtime

parent 82d30655
No related branches found
No related tags found
2 merge requests!59Robust mlflow requests,!57Do not force mlflow usage
This commit is part of merge request !57. Comments created here will be created in the context of that merge request.
......@@ -8,7 +8,6 @@ import sys
from datetime import date
from time import time
import mlflow
import numpy as np
import torch
import torch.distributed as dist
......@@ -22,9 +21,13 @@ from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from dan.manager.metrics import MetricManager
from dan.mlflow import logging_metrics, logging_tags_metrics
from dan.ocr.utils import LM_ind_to_str
from dan.schedulers import DropoutScheduler
try:
import mlflow
from dan.mlflow import logging_metrics, logging_tags_metrics
except ImportError:
pass
class GenericTrainingManager:
......
......@@ -8,11 +8,6 @@ from mlflow.exceptions import MlflowException
from dan import logger
class MLflowNotInstalled(Exception):
"""
Raised when MLflow logging was requested but the module was not installed
"""
def setup_environment(config: dict):
"""
......
......@@ -14,16 +14,17 @@ from dan import logger
from dan.decoder import GlobalHTADecoder
from dan.manager.ocr import OCRDataset, OCRDatasetManager
from dan.manager.training import Manager
from dan.mlflow import MLflowNotInstalled, start_mlflow_run
from dan.models import FCN_Encoder
from dan.schedulers import exponential_dropout_scheduler
from dan.transforms import aug_config
from dan.utils import MLflowNotInstalled
try:
import mlflow
MLFLOW = True
logger.info("MLflow Logging available.")
from dan.mlflow import start_mlflow_run
except ImportError:
MLFLOW = False
......@@ -70,9 +71,9 @@ def get_config():
Retrieve model configuration
"""
dataset_name = "esposalles"
dataset_level = "record"
dataset_variant = "_debug"
dataset_path = ""
dataset_level = "page"
dataset_variant = ""
dataset_path = "/home/training_data/ATR_paragraph/Esposalles"
params = {
"mlflow": {
"dataset_name": dataset_name,
......@@ -226,7 +227,7 @@ def get_config():
},
}
return params
return params, dataset_name
def serialize_config(config):
......@@ -277,17 +278,14 @@ def run():
Main program, training a new model, using a valid configuration
"""
config = get_config()
config_artifact = serialize_config(config)
labels_artifact = ""
# The only key of this dict is the name of the dataset
dataset_name = config_artifact["dataset_params"]["datasets"].keys()[0]
labels_path = (
Path(config_artifact["dataset_params"]["datasets"][dataset_name])
/ "labels.json"
)
config, dataset_name = get_config()
if MLFLOW and config["mlflow"]:
if MLFLOW and "mlflow" in config:
config_artifact = serialize_config(config)
labels_path = (
Path(config_artifact["dataset_params"]["datasets"][dataset_name])
/ "labels.json"
)
with start_mlflow_run(config["mlflow"]) as run:
logger.info(f"Set tags to MLflow on {config['mlflow']['run_name']}")
mlflow.set_tags({"Dataset": config["mlflow"]["dataset_name"]})
......@@ -312,7 +310,7 @@ def run():
)
else:
train_and_test(0, config, True)
elif config["mlflow"]:
elif "mlflow" in config:
logger.error(
"Cannot log to MLflow as the `mlflow` module was not found in your environment."
)
......
......@@ -18,6 +18,11 @@ SEM_MATCHING_TOKENS_STR = {
SEM_MATCHING_TOKENS = {"": "", "": "", "": "", "": "", "": "", "": ""}
class MLflowNotInstalled(Exception):
"""
Raised when MLflow logging was requested but the module was not installed
"""
def randint(low, high):
"""
call torch.randint to preserve random among dataloader workers
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment