Skip to content
Snippets Groups Projects
Commit 18f354a2 authored by Yoann Schneider's avatar Yoann Schneider :tennis: Committed by Bastien Abadie
Browse files

Do not force mlflow usage

parent ac1f954c
No related branches found
No related tags found
1 merge request: !57 "Do not force mlflow usage"
......@@ -8,7 +8,6 @@ import sys
from datetime import date
from time import time
import mlflow
import numpy as np
import torch
import torch.distributed as dist
......@@ -22,10 +21,16 @@ from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from dan.manager.metrics import MetricManager
from dan.mlflow import logging_metrics, logging_tags_metrics
from dan.ocr.utils import LM_ind_to_str
from dan.schedulers import DropoutScheduler
try:
import mlflow
from dan.mlflow import logging_metrics, logging_tags_metrics
except ImportError:
pass
class GenericTrainingManager:
def __init__(self, params):
......
......@@ -5,19 +5,28 @@ import random
from copy import deepcopy
from pathlib import Path
import mlflow
import numpy as np
import torch
import torch.multiprocessing as mp
from torch.optim import Adam
from dan import logger
from dan.decoder import GlobalHTADecoder
from dan.manager.ocr import OCRDataset, OCRDatasetManager
from dan.manager.training import Manager
from dan.mlflow import start_mlflow_run
from dan.models import FCN_Encoder
from dan.schedulers import exponential_dropout_scheduler
from dan.transforms import aug_config
from dan.utils import MLflowNotInstalled
try:
import mlflow
MLFLOW = True
logger.info("MLflow Logging available.")
from dan.mlflow import start_mlflow_run
except ImportError:
MLFLOW = False
logger = logging.getLogger(__name__)
......@@ -62,9 +71,9 @@ def get_config():
Retrieve model configuration
"""
dataset_name = "esposalles"
dataset_level = "record"
dataset_variant = "_debug"
dataset_path = ""
dataset_level = "page"
dataset_variant = ""
dataset_path = "/home/training_data/ATR_paragraph/Esposalles"
params = {
"mlflow": {
"dataset_name": dataset_name,
......@@ -218,7 +227,7 @@ def get_config():
},
}
return params
return params, dataset_name
def serialize_config(config):
......@@ -269,16 +278,14 @@ def run():
Main program, training a new model, using a valid configuration
"""
config = get_config()
config_artifact = serialize_config(config)
labels_artifact = ""
dataset_name = config["mlflow"]["dataset_name"]
labels_path = (
Path(config_artifact["dataset_params"]["datasets"][dataset_name])
/ "labels.json"
)
config, dataset_name = get_config()
if config["mlflow"]:
if MLFLOW and "mlflow" in config:
config_artifact = serialize_config(config)
labels_path = (
Path(config_artifact["dataset_params"]["datasets"][dataset_name])
/ "labels.json"
)
with start_mlflow_run(config["mlflow"]) as run:
logger.info(f"Set tags to MLflow on {config['mlflow']['run_name']}")
mlflow.set_tags({"Dataset": config["mlflow"]["dataset_name"]})
......@@ -303,6 +310,11 @@ def run():
)
else:
train_and_test(0, config, True)
elif "mlflow" in config:
logger.error(
"Cannot log to MLflow as the `mlflow` module was not found in your environment."
)
raise MLflowNotInstalled()
else:
if (
config["training_params"]["use_ddp"]
......
......@@ -18,6 +18,12 @@ SEM_MATCHING_TOKENS_STR = {
SEM_MATCHING_TOKENS = {"": "", "": "", "": "", "": "", "": "", "": ""}
class MLflowNotInstalled(Exception):
    """
    Raised when MLflow logging was requested but the module was not installed.

    When raised without arguments, the exception carries a default
    explanatory message instead of an empty one. Any positional arguments
    that are given are forwarded unchanged to `Exception`.
    """

    # Message used when the exception is raised without any argument.
    DEFAULT_MESSAGE = (
        "MLflow logging was requested but the `mlflow` module is not installed."
    )

    def __init__(self, *args):
        # Fall back to the default message only when the caller gave none,
        # so a bare `raise MLflowNotInstalled()` still reads well in a traceback.
        super().__init__(*(args or (self.DEFAULT_MESSAGE,)))
def randint(low, high):
"""
call torch.randint to preserve random among dataloader workers
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment