Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • atr/dan
1 result
Show changes
Commits on Source (3)
...@@ -8,6 +8,7 @@ import sys ...@@ -8,6 +8,7 @@ import sys
from datetime import date from datetime import date
from time import time from time import time
import mlflow
import numpy as np import numpy as np
import torch import torch
import torch.distributed as dist import torch.distributed as dist
...@@ -21,6 +22,7 @@ from torch.utils.tensorboard import SummaryWriter ...@@ -21,6 +22,7 @@ from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm from tqdm import tqdm
from dan.manager.metrics import MetricManager from dan.manager.metrics import MetricManager
from dan.mlflow import logging_metrics, logging_tags_metrics
from dan.ocr.utils import LM_ind_to_str from dan.ocr.utils import LM_ind_to_str
from dan.schedulers import DropoutScheduler from dan.schedulers import DropoutScheduler
...@@ -401,7 +403,7 @@ class GenericTrainingManager: ...@@ -401,7 +403,7 @@ class GenericTrainingManager:
"lr_schedulers" "lr_schedulers"
][key]["class"]( ][key]["class"](
self.optimizers[model_name], self.optimizers[model_name],
**self.params["training_params"]["lr_schedulers"][key]["args"] **self.params["training_params"]["lr_schedulers"][key]["args"],
) )
# Load optimizer state from past training # Load optimizer state from past training
...@@ -571,10 +573,11 @@ class GenericTrainingManager: ...@@ -571,10 +573,11 @@ class GenericTrainingManager:
def zero_optimizer(self, model_name, set_to_none=True): def zero_optimizer(self, model_name, set_to_none=True):
self.optimizers[model_name].zero_grad(set_to_none=set_to_none) self.optimizers[model_name].zero_grad(set_to_none=set_to_none)
def train(self): def train(self, mlflow_logging=False):
""" """
Main training loop Main training loop
""" """
# init tensorboard file and output param summary file # init tensorboard file and output param summary file
if self.is_master: if self.is_master:
self.writer = SummaryWriter(self.paths["results"]) self.writer = SummaryWriter(self.paths["results"])
...@@ -673,8 +676,13 @@ class GenericTrainingManager: ...@@ -673,8 +676,13 @@ class GenericTrainingManager:
pbar.set_postfix(values=str(display_values)) pbar.set_postfix(values=str(display_values))
pbar.update(len(batch_data["names"])) pbar.update(len(batch_data["names"]))
# log metrics in tensorboard file # Log MLflow metrics
logging_metrics(
display_values, "train", num_epoch, mlflow_logging, self.is_master
)
if self.is_master: if self.is_master:
# log metrics in tensorboard file
for key in display_values.keys(): for key in display_values.keys():
self.writer.add_scalar( self.writer.add_scalar(
"{}_{}".format( "{}_{}".format(
...@@ -693,7 +701,9 @@ class GenericTrainingManager: ...@@ -693,7 +701,9 @@ class GenericTrainingManager:
): ):
for valid_set_name in self.dataset.valid_loaders.keys(): for valid_set_name in self.dataset.valid_loaders.keys():
# evaluate set and compute metrics # evaluate set and compute metrics
eval_values = self.evaluate(valid_set_name) eval_values = self.evaluate(
valid_set_name, mlflow_logging=mlflow_logging
)
self.latest_valid_metrics = eval_values self.latest_valid_metrics = eval_values
# log valid metrics in tensorboard file # log valid metrics in tensorboard file
if self.is_master: if self.is_master:
...@@ -742,7 +752,7 @@ class GenericTrainingManager: ...@@ -742,7 +752,7 @@ class GenericTrainingManager:
self.save_model(epoch=num_epoch, name="weights", keep_weights=True) self.save_model(epoch=num_epoch, name="weights", keep_weights=True)
self.writer.flush() self.writer.flush()
def evaluate(self, set_name, **kwargs): def evaluate(self, set_name, mlflow_logging=False, **kwargs):
""" """
Main loop for validation Main loop for validation
""" """
...@@ -781,11 +791,23 @@ class GenericTrainingManager: ...@@ -781,11 +791,23 @@ class GenericTrainingManager:
pbar.set_postfix(values=str(display_values)) pbar.set_postfix(values=str(display_values))
pbar.update(len(batch_data["names"])) pbar.update(len(batch_data["names"]))
# log metrics in MLflow
logging_metrics(
display_values,
"val",
self.latest_epoch,
mlflow_logging,
self.is_master,
)
if "cer_by_nb_cols" in metric_names: if "cer_by_nb_cols" in metric_names:
self.log_cer_by_nb_cols(set_name) self.log_cer_by_nb_cols(set_name)
return display_values return display_values
def predict(self, custom_name, sets_list, metric_names, output=False): def predict(
self, custom_name, sets_list, metric_names, mlflow_logging=False, output=False
):
""" """
Main loop for evaluation Main loop for evaluation
""" """
...@@ -828,6 +850,12 @@ class GenericTrainingManager: ...@@ -828,6 +850,12 @@ class GenericTrainingManager:
pbar.set_postfix(values=str(display_values)) pbar.set_postfix(values=str(display_values))
pbar.update(len(batch_data["names"])) pbar.update(len(batch_data["names"]))
# log metrics in MLflow
logging_name = custom_name.split("-")[1]
logging_tags_metrics(
display_values, logging_name, mlflow_logging, self.is_master
)
# output metrics values if requested # output metrics values if requested
if output: if output:
if "pred" in metric_names: if "pred" in metric_names:
...@@ -841,6 +869,9 @@ class GenericTrainingManager: ...@@ -841,6 +869,9 @@ class GenericTrainingManager:
for metric_name in metrics.keys(): for metric_name in metrics.keys():
f.write("{}: {}\n".format(metric_name, metrics[metric_name])) f.write("{}: {}\n".format(metric_name, metrics[metric_name]))
# Log mlflow artifacts
mlflow.log_artifact(path, "predictions")
def output_pred(self, name): def output_pred(self, name):
path = os.path.join( path = os.path.join(
self.paths["results"], "pred_{}_{}.txt".format(name, self.latest_epoch) self.paths["results"], "pred_{}_{}.txt".format(name, self.latest_epoch)
......
# -*- coding: utf-8 -*-
import os
from contextlib import contextmanager
import mlflow
from mlflow.exceptions import MlflowException
from dan import logger
def setup_environment(config: dict):
    """
    Get the necessary variables from the config file and put them in the environment variables.

    Only keys that are present in the config with a non-empty value are exported.
    A blank placeholder (e.g. ``"aws_access_key_id": ""``) is ignored so that it
    does not overwrite a value already defined in the environment.

    :param config: dict, the config of the model
    """
    # Environment variable name -> key to read it from in the config
    needed_variables = {
        "MLFLOW_S3_ENDPOINT_URL": "s3_endpoint_url",
        "MLFLOW_TRACKING_URI": "tracking_uri",
        "AWS_ACCESS_KEY_ID": "aws_access_key_id",
        "AWS_SECRET_ACCESS_KEY": "aws_secret_access_key",
    }
    for variable_name, config_key in needed_variables.items():
        value = config.get(config_key)
        # Skip missing or blank entries: exporting "" would clobber a valid
        # variable inherited from the calling environment.
        if value:
            os.environ[variable_name] = value
def logging_metrics(
    display_values: dict,
    step: str,
    epoch: int,
    mlflow_logging: bool = False,
    is_master: bool = False,
):
    """
    Publish a dictionary of metrics in the Metrics section of MLflow.

    Every metric name is prefixed with ``<step>_`` before being sent. Nothing is
    sent unless both ``mlflow_logging`` and ``is_master`` are true.

    :param display_values: dict, the metrics to publish on MLflow
    :param step: str, the prefix added to each metric name (ex: train, val, test)
    :param epoch: int, the current epoch, passed to MLflow as the metrics step
    :param mlflow_logging: bool, whether logging to MLflow is enabled, defaults to False
    :param is_master: bool, whether the caller is the master process, defaults to False
    """
    if not (mlflow_logging and is_master):
        return

    prefixed_metrics = {}
    for metric_name, metric_value in display_values.items():
        prefixed_metrics[f"{step}_{metric_name}"] = metric_value
    mlflow.log_metrics(prefixed_metrics, epoch)
def logging_tags_metrics(
    display_values: dict,
    step: str,
    mlflow_logging: bool = False,
    is_master: bool = False,
):
    """
    Publish a dictionary of metrics in the Tags section of MLflow.

    Every tag name is prefixed with ``<step>_`` before being sent. Nothing is
    sent unless both ``mlflow_logging`` and ``is_master`` are true.

    :param display_values: dict, the metrics to publish on MLflow as tags
    :param step: str, the prefix added to each tag name (ex: train, val, test)
    :param mlflow_logging: bool, whether logging to MLflow is enabled, defaults to False
    :param is_master: bool, whether the caller is the master process, defaults to False
    """
    logging_enabled = mlflow_logging and is_master
    if not logging_enabled:
        return

    tags = {f"{step}_{tag_name}": tag_value for tag_name, tag_value in display_values.items()}
    mlflow.set_tags(tags)
@contextmanager
def start_mlflow_run(config: dict):
    """
    Create an MLflow execution context with the parameters contained in the config file.

    The run is opened through MLflow's own context manager, so it is always
    closed when the context exits, including when the managed body raises.

    :param config: dict, the config of the model
    :raises AssertionError: if the config has no (or an empty) ``experiment_id``
    :raises MlflowException: if the experiment cannot be selected
    """
    # Set needed variables in environment
    setup_environment(config)

    # Set experiment from config
    experiment_id = config.get("experiment_id")
    assert experiment_id, "Missing MLflow experiment ID in the configuration"
    try:
        mlflow.set_experiment(experiment_id=experiment_id)
    except MlflowException:
        logger.error(f"Couldn't set Mlflow experiment with ID: {experiment_id}")
        # Bare raise keeps the original traceback intact
        raise
    else:
        logger.info(f"Run Experiment ID : {experiment_id} on MLFlow")

    # Start run; leaving the `with` block ends it even if the caller's body fails,
    # which a plain `yield` followed by `mlflow.end_run()` would not do.
    with mlflow.start_run(run_name=config.get("run_name")) as active_run:
        yield active_run
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import json
import logging
import random import random
from copy import deepcopy
from pathlib import Path
import mlflow
import numpy as np import numpy as np
import torch import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
...@@ -9,12 +14,15 @@ from torch.optim import Adam ...@@ -9,12 +14,15 @@ from torch.optim import Adam
from dan.decoder import GlobalHTADecoder from dan.decoder import GlobalHTADecoder
from dan.manager.ocr import OCRDataset, OCRDatasetManager from dan.manager.ocr import OCRDataset, OCRDatasetManager
from dan.manager.training import Manager from dan.manager.training import Manager
from dan.mlflow import start_mlflow_run
from dan.models import FCN_Encoder from dan.models import FCN_Encoder
from dan.schedulers import exponential_dropout_scheduler, linear_scheduler from dan.schedulers import exponential_dropout_scheduler
from dan.transforms import aug_config from dan.transforms import aug_config
logger = logging.getLogger(__name__)
def train_and_test(rank, params):
def train_and_test(rank, params, mlflow_logging=False):
torch.manual_seed(0) torch.manual_seed(0)
torch.cuda.manual_seed(0) torch.cuda.manual_seed(0)
np.random.seed(0) np.random.seed(0)
...@@ -26,7 +34,10 @@ def train_and_test(rank, params): ...@@ -26,7 +34,10 @@ def train_and_test(rank, params):
model = Manager(params) model = Manager(params)
model.load_model() model.load_model()
model.train() if mlflow_logging:
logger.info("MLflow logging enabled")
model.train(mlflow_logging=mlflow_logging)
# load weights giving best CER on valid set # load weights giving best CER on valid set
model.params["training_params"]["load_epoch"] = "best" model.params["training_params"]["load_epoch"] = "best"
...@@ -42,21 +53,34 @@ def train_and_test(rank, params): ...@@ -42,21 +53,34 @@ def train_and_test(rank, params):
], ],
metrics, metrics,
output=True, output=True,
mlflow_logging=mlflow_logging,
) )
def run(): def get_config():
"""
Retrieve model configuration
"""
dataset_name = "esposalles" dataset_name = "esposalles"
dataset_level = "record" dataset_level = "record"
dataset_variant = "" dataset_variant = "_debug"
dataset_path = ""
params = { params = {
"mlflow": {
"dataset_name": dataset_name,
"run_name": "Test log DAN",
"s3_endpoint_url": "",
"tracking_uri": "",
"experiment_id": "9",
"aws_access_key_id": "",
"aws_secret_access_key": "",
},
"dataset_params": { "dataset_params": {
"dataset_manager": OCRDatasetManager, "dataset_manager": OCRDatasetManager,
"dataset_class": OCRDataset, "dataset_class": OCRDataset,
"datasets": { "datasets": {
dataset_name: "{}_{}{}".format( dataset_name: "{}/{}_{}{}".format(
dataset_name, dataset_level, dataset_variant dataset_path, dataset_name, dataset_level, dataset_variant
), ),
}, },
"train": { "train": {
...@@ -90,41 +114,7 @@ def run(): ...@@ -90,41 +114,7 @@ def run():
}, },
], ],
"augmentation": aug_config(0.9, 0.1), "augmentation": aug_config(0.9, 0.1),
# "synthetic_data": None, "synthetic_data": None,
"synthetic_data": {
"init_proba": 0.9, # begin proba to generate synthetic document
"end_proba": 0.2, # end proba to generate synthetic document
"num_steps_proba": 200000, # linearly decrease the percent of synthetic document from 90% to 20% through 200000 samples
"proba_scheduler_function": linear_scheduler, # decrease proba rate linearly
"start_scheduler_at_max_line": True, # start decreasing proba only after curriculum reach max number of lines
"dataset_level": dataset_level,
"curriculum": True, # use curriculum learning (slowly increase number of lines per synthetic samples)
"crop_curriculum": True, # during curriculum learning, crop images under the last text line
"curr_start": 0, # start curriculum at iteration
"curr_step": 10000, # interval to increase the number of lines for curriculum learning
"min_nb_lines": 1, # initial number of lines for curriculum learning
"max_nb_lines": 4, # maximum number of lines for curriculum learning
"padding_value": 255,
"font_path": "fonts/",
# config for synthetic line generation
"config": {
"background_color_default": (255, 255, 255),
"background_color_eps": 15,
"text_color_default": (0, 0, 0),
"text_color_eps": 15,
"font_size_min": 35,
"font_size_max": 45,
"color_mode": "RGB",
"padding_left_ratio_min": 0.00,
"padding_left_ratio_max": 0.05,
"padding_right_ratio_min": 0.02,
"padding_right_ratio_max": 0.2,
"padding_top_ratio_min": 0.02,
"padding_top_ratio_max": 0.1,
"padding_bottom_ratio_min": 0.02,
"padding_bottom_ratio_max": 0.1,
},
},
}, },
}, },
"model_params": { "model_params": {
...@@ -175,7 +165,7 @@ def run(): ...@@ -175,7 +165,7 @@ def run():
}, },
"training_params": { "training_params": {
"output_folder": "dan_esposalles_record", # folder name for checkpoint and results "output_folder": "dan_esposalles_record", # folder name for checkpoint and results
"max_nb_epochs": 50000, # maximum number of epochs before to stop "max_nb_epochs": 710, # maximum number of epochs before to stop
"max_training_time": 3600 "max_training_time": 3600
* 24 * 24
* 1.9, # maximum time before to stop (in seconds) * 1.9, # maximum time before to stop (in seconds)
...@@ -228,12 +218,100 @@ def run(): ...@@ -228,12 +218,100 @@ def run():
}, },
} }
if ( return params
params["training_params"]["use_ddp"]
and not params["training_params"]["force_cpu"]
def serialize_config(config):
    """
    Build a JSON-friendly copy of the configuration with the credentials blanked out.

    The input dictionary is left untouched: all changes are made on a deep copy.
    Classes are replaced by their name, and the augmentation config, the dropout
    scheduler function and the number of GPUs are replaced by their ``str()`` form.
    """
    # Work on a copy so the caller's config keeps its classes and credentials
    serialized = deepcopy(config)

    # Blank out the credentials and endpoints
    mlflow_section = serialized["mlflow"]
    for secret_key in (
        "s3_endpoint_url",
        "tracking_uri",
        "aws_access_key_id",
        "aws_secret_access_key",
    ):
        mlflow_section[secret_key] = ""

    # Replace classes by their name
    dataset_params = serialized["dataset_params"]
    for class_key in ("dataset_manager", "dataset_class"):
        dataset_params[class_key] = dataset_params[class_key].__name__

    model_params = serialized["model_params"]
    models = model_params["models"]
    for model_key in ("encoder", "decoder"):
        models[model_key] = models[model_key].__name__

    training_params = serialized["training_params"]
    optimizer = training_params["optimizers"]["all"]
    optimizer["class"] = optimizer["class"].__name__

    # Cast the remaining non-serializable values to str
    dataset_config = dataset_params["config"]
    dataset_config["augmentation"] = str(dataset_config["augmentation"])

    dropout_scheduler = model_params["dropout_scheduler"]
    dropout_scheduler["function"] = str(dropout_scheduler["function"])

    training_params["nb_gpu"] = str(training_params["nb_gpu"])

    return serialized
def _launch_training(config, mlflow_logging):
    """
    Call ``train_and_test`` directly, or through ``mp.spawn`` with ``nb_gpu``
    processes when ``use_ddp`` is set and ``force_cpu`` is not.
    """
    training_params = config["training_params"]
    if training_params["use_ddp"] and not training_params["force_cpu"]:
        mp.spawn(
            train_and_test,
            args=(config, mlflow_logging),
            nprocs=training_params["nb_gpu"],
        )
    else:
        train_and_test(0, config, mlflow_logging)


def run():
    """
    Main program, training a new model, using a valid configuration.

    When the configuration has a non-empty ``mlflow`` section, training runs inside
    an MLflow run: the dataset tag, the serialized config and the dataset labels are
    logged first, then training starts with MLflow logging enabled. Otherwise training
    starts with MLflow logging disabled and nothing MLflow-related is evaluated.
    """
    config = get_config()

    mlflow_config = config.get("mlflow")
    if not mlflow_config:
        # No MLflow settings: train without any MLflow logging
        _launch_training(config, mlflow_logging=False)
        return

    # Everything below needs the MLflow section, so it is only built on this path
    config_artifact = serialize_config(config)
    dataset_name = mlflow_config["dataset_name"]
    labels_path = (
        Path(config_artifact["dataset_params"]["datasets"][dataset_name])
        / "labels.json"
    )

    with start_mlflow_run(mlflow_config) as active_run:
        logger.info(f"Set tags to MLflow on {mlflow_config['run_name']}")
        mlflow.set_tags({"Dataset": dataset_name})

        # Get the labels json file
        with open(labels_path) as json_file:
            labels_artifact = json.load(json_file)

        # Log MLflow artifacts
        mlflow.log_dict(config_artifact, "config.json")
        mlflow.log_dict(labels_artifact, "labels.json")

        logger.info(f"Started MLflow run with ID ({active_run.info.run_id})")

        _launch_training(config, mlflow_logging=True)
...@@ -3,5 +3,5 @@ doc8==1.0.0 ...@@ -3,5 +3,5 @@ doc8==1.0.0
mkdocs==1.4.2 mkdocs==1.4.2
mkdocs-material==8.5.11 mkdocs-material==8.5.11
mkdocstrings==0.19.0 mkdocstrings==0.19.0
mkdocstrings-python==0.8.2 mkdocstrings-python==0.8.3
recommonmark==0.7.1 recommonmark==0.7.1
# MLflow
::: dan.mlflow
...@@ -94,6 +94,7 @@ nav: ...@@ -94,6 +94,7 @@ nav:
- Utils: ref/ocr/line/utils.md - Utils: ref/ocr/line/utils.md
- Decoders: ref/decoder.md - Decoders: ref/decoder.md
- Models: ref/models.md - Models: ref/models.md
- MLflow: ref/mlflow.md
- Post Processing: ref/post_processing.md - Post Processing: ref/post_processing.md
- Inference: ref/predict.md - Inference: ref/predict.md
- Schedulers: ref/schedulers.md - Schedulers: ref/schedulers.md
......
arkindex-client==1.0.11 arkindex-client==1.0.11
boto3==1.26.51
editdistance==0.6.1 editdistance==0.6.1
fontTools==4.38.0 fontTools==4.38.0
imageio==2.22.4 imageio==2.22.4
mlflow==2.0.1
networkx==2.8.8 networkx==2.8.8
numpy==1.23.5 numpy==1.23.5
opencv-python==4.6.0.66 opencv-python==4.6.0.66
PyYAML==6.0 PyYAML==6.0
scipy==1.9.3 scipy==1.9.3
tensorboard==2.11.0 tensorboard==2.11.0
torch==1.13.0 torch==1.13.1
torchvision==0.14.0 torchvision==0.14.1
tqdm==4.64.1 tqdm==4.64.1