Skip to content
Snippets Groups Projects
Commit 18f354a2 authored by Yoann Schneider's avatar Yoann Schneider :tennis: Committed by Bastien Abadie
Browse files

Do not force mlflow usage

parent ac1f954c
No related branches found
No related tags found
1 merge request!57Do not force mlflow usage
@@ -8,7 +8,6 @@ import sys
from datetime import date
from time import time
import mlflow
import numpy as np
import torch
import torch.distributed as dist
@@ -22,10 +21,16 @@ from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from dan.manager.metrics import MetricManager
from dan.mlflow import logging_metrics, logging_tags_metrics
from dan.ocr.utils import LM_ind_to_str
from dan.schedulers import DropoutScheduler
try:
import mlflow
from dan.mlflow import logging_metrics, logging_tags_metrics
except ImportError:
pass
class GenericTrainingManager:
    def __init__(self, params):
...
@@ -5,19 +5,28 @@ import random
from copy import deepcopy
from pathlib import Path
import mlflow
import numpy as np
import torch
import torch.multiprocessing as mp
from torch.optim import Adam
from dan import logger
from dan.decoder import GlobalHTADecoder
from dan.manager.ocr import OCRDataset, OCRDatasetManager
from dan.manager.training import Manager
from dan.mlflow import start_mlflow_run
from dan.models import FCN_Encoder
from dan.schedulers import exponential_dropout_scheduler
from dan.transforms import aug_config
from dan.utils import MLflowNotInstalled
try:
import mlflow
MLFLOW = True
logger.info("MLflow Logging available.")
from dan.mlflow import start_mlflow_run
except ImportError:
MLFLOW = False
logger = logging.getLogger(__name__)
@@ -62,9 +71,9 @@ def get_config():
    Retrieve model configuration
    """
    dataset_name = "esposalles"
dataset_level = "record" dataset_level = "page"
dataset_variant = "_debug" dataset_variant = ""
dataset_path = "" dataset_path = "/home/training_data/ATR_paragraph/Esposalles"
    params = {
        "mlflow": {
            "dataset_name": dataset_name,
@@ -218,7 +227,7 @@ def get_config():
        },
    }
return params return params, dataset_name
def serialize_config(config):
@@ -269,16 +278,14 @@ def run():
    Main program, training a new model, using a valid configuration
    """
config = get_config() config, dataset_name = get_config()
config_artifact = serialize_config(config)
labels_artifact = ""
dataset_name = config["mlflow"]["dataset_name"]
labels_path = (
Path(config_artifact["dataset_params"]["datasets"][dataset_name])
/ "labels.json"
)
if config["mlflow"]: if MLFLOW and "mlflow" in config:
config_artifact = serialize_config(config)
labels_path = (
Path(config_artifact["dataset_params"]["datasets"][dataset_name])
/ "labels.json"
)
        with start_mlflow_run(config["mlflow"]) as run:
            logger.info(f"Set tags to MLflow on {config['mlflow']['run_name']}")
            mlflow.set_tags({"Dataset": config["mlflow"]["dataset_name"]})
@@ -303,6 +310,11 @@ def run():
                )
            else:
                train_and_test(0, config, True)
elif "mlflow" in config:
logger.error(
"Cannot log to MLflow as the `mlflow` module was not found in your environment."
)
raise MLflowNotInstalled()
    else:
        if (
            config["training_params"]["use_ddp"]
...
@@ -18,6 +18,12 @@ SEM_MATCHING_TOKENS_STR = {
SEM_MATCHING_TOKENS = {"": "", "": "", "": "", "": "", "": "", "": ""}
class MLflowNotInstalled(Exception):
    """Raised when MLflow logging is requested but the `mlflow` package is not installed."""
def randint(low, high):
    """
    call torch.randint to preserve random among dataloader workers
...
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment