Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • atr/dan
1 result
Show changes
Commits on Source (2)
......@@ -22,6 +22,7 @@ def save_json(path, dict):
def insert_token(text, count, start_token, end_token, offset, length):
"""
Insert the given tokens at the right position in the text
start_token or end_token can be empty strings
"""
text = (
# Text before entity
......@@ -35,7 +36,9 @@ def insert_token(text, count, start_token, end_token, offset, length):
# Text after entity
+ text[count + 1 + offset + length :]
)
return text, count + 2
token_offset = len(start_token) + len(end_token)
return text, count + token_offset
def parse_tokens(filename):
......
......@@ -8,6 +8,7 @@ import sys
from datetime import date
from time import time
import mlflow
import numpy as np
import torch
import torch.distributed as dist
......@@ -21,6 +22,7 @@ from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from dan.manager.metrics import MetricManager
from dan.mlflow import logging_metrics, logging_tags_metrics
from dan.ocr.utils import LM_ind_to_str
from dan.schedulers import DropoutScheduler
......@@ -401,7 +403,7 @@ class GenericTrainingManager:
"lr_schedulers"
][key]["class"](
self.optimizers[model_name],
**self.params["training_params"]["lr_schedulers"][key]["args"]
**self.params["training_params"]["lr_schedulers"][key]["args"],
)
# Load optimizer state from past training
......@@ -571,10 +573,11 @@ class GenericTrainingManager:
def zero_optimizer(self, model_name, set_to_none=True):
self.optimizers[model_name].zero_grad(set_to_none=set_to_none)
def train(self):
def train(self, mlflow_logging=False):
"""
Main training loop
"""
# init tensorboard file and output param summary file
if self.is_master:
self.writer = SummaryWriter(self.paths["results"])
......@@ -673,8 +676,13 @@ class GenericTrainingManager:
pbar.set_postfix(values=str(display_values))
pbar.update(len(batch_data["names"]))
# log metrics in tensorboard file
# Log MLflow metrics
logging_metrics(
display_values, "train", num_epoch, mlflow_logging, self.is_master
)
if self.is_master:
# log metrics in tensorboard file
for key in display_values.keys():
self.writer.add_scalar(
"{}_{}".format(
......@@ -693,7 +701,9 @@ class GenericTrainingManager:
):
for valid_set_name in self.dataset.valid_loaders.keys():
# evaluate set and compute metrics
eval_values = self.evaluate(valid_set_name)
eval_values = self.evaluate(
valid_set_name, mlflow_logging=mlflow_logging
)
self.latest_valid_metrics = eval_values
# log valid metrics in tensorboard file
if self.is_master:
......@@ -742,7 +752,7 @@ class GenericTrainingManager:
self.save_model(epoch=num_epoch, name="weights", keep_weights=True)
self.writer.flush()
def evaluate(self, set_name, **kwargs):
def evaluate(self, set_name, mlflow_logging=False, **kwargs):
"""
Main loop for validation
"""
......@@ -781,11 +791,23 @@ class GenericTrainingManager:
pbar.set_postfix(values=str(display_values))
pbar.update(len(batch_data["names"]))
# log metrics in MLflow
logging_metrics(
display_values,
"val",
self.latest_epoch,
mlflow_logging,
self.is_master,
)
if "cer_by_nb_cols" in metric_names:
self.log_cer_by_nb_cols(set_name)
return display_values
def predict(self, custom_name, sets_list, metric_names, output=False):
def predict(
self, custom_name, sets_list, metric_names, mlflow_logging=False, output=False
):
"""
Main loop for evaluation
"""
......@@ -828,6 +850,12 @@ class GenericTrainingManager:
pbar.set_postfix(values=str(display_values))
pbar.update(len(batch_data["names"]))
# log metrics in MLflow
logging_name = custom_name.split("-")[1]
logging_tags_metrics(
display_values, logging_name, mlflow_logging, self.is_master
)
# output metrics values if requested
if output:
if "pred" in metric_names:
......@@ -841,6 +869,9 @@ class GenericTrainingManager:
for metric_name in metrics.keys():
f.write("{}: {}\n".format(metric_name, metrics[metric_name]))
# Log mlflow artifacts
mlflow.log_artifact(path, "predictions")
def output_pred(self, name):
path = os.path.join(
self.paths["results"], "pred_{}_{}.txt".format(name, self.latest_epoch)
......
# -*- coding: utf-8 -*-
import os
from contextlib import contextmanager
import mlflow
from mlflow.exceptions import MlflowException
from dan import logger
def setup_environment(config: dict):
    """
    Get the necessary variables from the config and put them in the environment variables

    Only the keys listed in ``needed_variables`` are read. A key that is absent
    from the config, or whose value is ``None`` or an empty string, is skipped so
    that a blank placeholder never overwrites a variable that is already set in
    the environment.
    :param config: dict, the configuration holding the MLflow/S3 settings
    """
    needed_variables = {
        "MLFLOW_S3_ENDPOINT_URL": "s3_endpoint_url",
        "MLFLOW_TRACKING_URI": "tracking_uri",
        "AWS_ACCESS_KEY_ID": "aws_access_key_id",
        "AWS_SECRET_ACCESS_KEY": "aws_secret_access_key",
    }
    for variable_name, config_key in needed_variables.items():
        value = config.get(config_key)
        # Unset or blank entries are ignored rather than exported
        if value is None or value == "":
            continue
        # os.environ only accepts strings, so stringify instead of raising TypeError
        os.environ[variable_name] = str(value)
def logging_metrics(
    display_values: dict,
    step: str,
    epoch: int,
    mlflow_logging: bool = False,
    is_master: bool = False,
):
    """
    Log dictionary metrics in the Metrics section of MLflow

    Nothing is sent unless both ``mlflow_logging`` and ``is_master`` are true and
    there is at least one metric to publish.
    :param display_values: dict, the dictionary containing the metrics to publish on MLflow
    :param step: str, the step for which the metrics are to be published on Metrics section (ex: train, val, test). It is used as a prefix of every metric name.
    :param epoch: int, the current epoch, forwarded to MLflow as the metric step.
    :param mlflow_logging: bool, whether MLflow logging is enabled, defaults to False
    :param is_master: bool, whether the caller is the master and is therefore the one allowed to log, defaults to False
    """
    # Logging disabled, non-master caller or nothing to publish: no MLflow call at all
    if not (mlflow_logging and is_master) or not display_values:
        return

    # Prefix every metric with the step so the sets stay apart in the MLflow UI
    mlflow_values = {f"{step}_{name}": value for name, value in display_values.items()}
    mlflow.log_metrics(mlflow_values, step=epoch)
def logging_tags_metrics(
    display_values: dict,
    step: str,
    mlflow_logging: bool = False,
    is_master: bool = False,
):
    """
    Log dictionary metrics in the Tags section of MLflow

    Nothing is sent unless both ``mlflow_logging`` and ``is_master`` are true and
    there is at least one value to publish.
    :param display_values: dict, the dictionary containing the metrics to publish on MLflow
    :param step: str, the step for which the metrics are to be published on Tags section (ex: train, val, test). It is used as a prefix of every tag name.
    :param mlflow_logging: bool, whether MLflow logging is enabled, defaults to False
    :param is_master: bool, whether the caller is the master and is therefore the one allowed to log, defaults to False
    """
    # Logging disabled, non-master caller or nothing to publish: no MLflow call at all
    if not (mlflow_logging and is_master) or not display_values:
        return

    # Prefix every tag with the step so the sets stay apart in the MLflow UI
    mlflow_values = {f"{step}_{name}": value for name, value in display_values.items()}
    mlflow.set_tags(mlflow_values)
@contextmanager
def start_mlflow_run(config: dict):
    """
    Create an MLflow execution context with the parameters contained in the config file

    The run is closed when the ``with`` block exits, including when the block
    raises an exception.
    :param config: dict, the config of the model; ``experiment_id`` is required and ``run_name`` is optional
    :raises AssertionError: when ``experiment_id`` is missing or empty
    :raises MlflowException: when the experiment cannot be selected
    """
    # Set needed variables in environment
    setup_environment(config)

    # Set experiment from config
    experiment_id = config.get("experiment_id")
    assert experiment_id, "Missing MLflow experiment ID in the configuration"

    try:
        mlflow.set_experiment(experiment_id=experiment_id)
        logger.info(f"Run Experiment ID : {experiment_id} on MLFlow")
    except MlflowException:
        logger.error(f"Couldn't set Mlflow experiment with ID: {experiment_id}")
        # Bare raise keeps the original traceback
        raise

    # Start run; the ActiveRun context manager ends it on every exit path,
    # whereas a plain call after the yield is skipped when the caller's block raises
    with mlflow.start_run(run_name=config.get("run_name")) as active_run:
        yield active_run
# -*- coding: utf-8 -*-
import json
import logging
import random
from copy import deepcopy
from pathlib import Path
import mlflow
import numpy as np
import torch
import torch.multiprocessing as mp
......@@ -9,12 +14,15 @@ from torch.optim import Adam
from dan.decoder import GlobalHTADecoder
from dan.manager.ocr import OCRDataset, OCRDatasetManager
from dan.manager.training import Manager
from dan.mlflow import start_mlflow_run
from dan.models import FCN_Encoder
from dan.schedulers import exponential_dropout_scheduler, linear_scheduler
from dan.schedulers import exponential_dropout_scheduler
from dan.transforms import aug_config
logger = logging.getLogger(__name__)
def train_and_test(rank, params):
def train_and_test(rank, params, mlflow_logging=False):
torch.manual_seed(0)
torch.cuda.manual_seed(0)
np.random.seed(0)
......@@ -26,7 +34,10 @@ def train_and_test(rank, params):
model = Manager(params)
model.load_model()
model.train()
if mlflow_logging:
logger.info("MLflow logging enabled")
model.train(mlflow_logging=mlflow_logging)
# load weights giving best CER on valid set
model.params["training_params"]["load_epoch"] = "best"
......@@ -42,21 +53,34 @@ def train_and_test(rank, params):
],
metrics,
output=True,
mlflow_logging=mlflow_logging,
)
def run():
def get_config():
"""
Retrieve model configuration
"""
dataset_name = "esposalles"
dataset_level = "record"
dataset_variant = ""
dataset_variant = "_debug"
dataset_path = ""
params = {
"mlflow": {
"dataset_name": dataset_name,
"run_name": "Test log DAN",
"s3_endpoint_url": "",
"tracking_uri": "",
"experiment_id": "9",
"aws_access_key_id": "",
"aws_secret_access_key": "",
},
"dataset_params": {
"dataset_manager": OCRDatasetManager,
"dataset_class": OCRDataset,
"datasets": {
dataset_name: "{}_{}{}".format(
dataset_name, dataset_level, dataset_variant
dataset_name: "{}/{}_{}{}".format(
dataset_path, dataset_name, dataset_level, dataset_variant
),
},
"train": {
......@@ -90,41 +114,7 @@ def run():
},
],
"augmentation": aug_config(0.9, 0.1),
# "synthetic_data": None,
"synthetic_data": {
"init_proba": 0.9, # begin proba to generate synthetic document
"end_proba": 0.2, # end proba to generate synthetic document
"num_steps_proba": 200000, # linearly decrease the percent of synthetic document from 90% to 20% through 200000 samples
"proba_scheduler_function": linear_scheduler, # decrease proba rate linearly
"start_scheduler_at_max_line": True, # start decreasing proba only after curriculum reach max number of lines
"dataset_level": dataset_level,
"curriculum": True, # use curriculum learning (slowly increase number of lines per synthetic samples)
"crop_curriculum": True, # during curriculum learning, crop images under the last text line
"curr_start": 0, # start curriculum at iteration
"curr_step": 10000, # interval to increase the number of lines for curriculum learning
"min_nb_lines": 1, # initial number of lines for curriculum learning
"max_nb_lines": 4, # maximum number of lines for curriculum learning
"padding_value": 255,
"font_path": "fonts/",
# config for synthetic line generation
"config": {
"background_color_default": (255, 255, 255),
"background_color_eps": 15,
"text_color_default": (0, 0, 0),
"text_color_eps": 15,
"font_size_min": 35,
"font_size_max": 45,
"color_mode": "RGB",
"padding_left_ratio_min": 0.00,
"padding_left_ratio_max": 0.05,
"padding_right_ratio_min": 0.02,
"padding_right_ratio_max": 0.2,
"padding_top_ratio_min": 0.02,
"padding_top_ratio_max": 0.1,
"padding_bottom_ratio_min": 0.02,
"padding_bottom_ratio_max": 0.1,
},
},
"synthetic_data": None,
},
},
"model_params": {
......@@ -175,7 +165,7 @@ def run():
},
"training_params": {
"output_folder": "dan_esposalles_record", # folder name for checkpoint and results
"max_nb_epochs": 50000, # maximum number of epochs before to stop
"max_nb_epochs": 710, # maximum number of epochs before to stop
"max_training_time": 3600
* 24
* 1.9, # maximum time before to stop (in seconds)
......@@ -228,12 +218,100 @@ def run():
},
}
if (
params["training_params"]["use_ddp"]
and not params["training_params"]["force_cpu"]
):
mp.spawn(
train_and_test, args=(params,), nprocs=params["training_params"]["nb_gpu"]
)
return params
def serialize_config(config):
    """
    Serialize a dictionary to transform it into json and remove the credentials

    The input is left untouched: all changes are made on a deep copy. Classes are
    replaced by their name, the augmentation config and the dropout scheduler
    function by their string representation, and ``nb_gpu`` by a string.
    The ``mlflow`` section is optional; when present, its credentials are blanked.
    :param config: dict, the full training configuration
    :return: dict, the copy of the configuration after these replacements
    """
    # Create a copy of the original config so that it is not modified
    serialized_config = deepcopy(config)

    # Remove the credentials from the config (only when an MLflow section exists)
    mlflow_config = serialized_config.get("mlflow")
    if isinstance(mlflow_config, dict):
        for credential_key in (
            "s3_endpoint_url",
            "tracking_uri",
            "aws_access_key_id",
            "aws_secret_access_key",
        ):
            mlflow_config[credential_key] = ""

    dataset_params = serialized_config["dataset_params"]
    model_params = serialized_config["model_params"]
    training_params = serialized_config["training_params"]

    # Get the name of the classes
    for class_key in ("dataset_manager", "dataset_class"):
        dataset_params[class_key] = dataset_params[class_key].__name__
    models = model_params["models"]
    for model_key in ("encoder", "decoder"):
        models[model_key] = models[model_key].__name__
    optimizer = training_params["optimizers"]["all"]
    optimizer["class"] = optimizer["class"].__name__

    # Cast the functions to str
    dataset_config = dataset_params["config"]
    dataset_config["augmentation"] = str(dataset_config["augmentation"])
    dropout_scheduler = model_params["dropout_scheduler"]
    dropout_scheduler["function"] = str(dropout_scheduler["function"])

    # The number of GPUs is stored as a string as well
    training_params["nb_gpu"] = str(training_params["nb_gpu"])

    return serialized_config
def _launch_training(config, mlflow_logging):
    """
    Run train_and_test through mp.spawn when DDP is requested and CPU is not forced, else in the current process
    """
    training_params = config["training_params"]
    if training_params["use_ddp"] and not training_params["force_cpu"]:
        # mp.spawn passes the process index as the first argument (rank)
        mp.spawn(
            train_and_test,
            args=(config, mlflow_logging),
            nprocs=training_params["nb_gpu"],
        )
    else:
        train_and_test(0, config, mlflow_logging)


def run():
    """
    Main program, training a new model, using a valid configuration

    When the configuration has a non-empty ``mlflow`` section, the training runs
    inside an MLflow run: the dataset name is set as a tag, and the serialized
    configuration and the dataset ``labels.json`` are logged as artifacts before
    training starts. Otherwise the training runs with MLflow logging disabled.
    """
    config = get_config()
    mlflow_config = config.get("mlflow")

    if not mlflow_config:
        # No MLflow settings: train without logging anything to MLflow
        _launch_training(config, mlflow_logging=False)
        return

    config_artifact = serialize_config(config)
    dataset_name = mlflow_config["dataset_name"]
    labels_path = (
        Path(config_artifact["dataset_params"]["datasets"][dataset_name])
        / "labels.json"
    )

    with start_mlflow_run(mlflow_config) as mlflow_run:
        logger.info(f"Set tags to MLflow on {config['mlflow']['run_name']}")
        mlflow.set_tags({"Dataset": dataset_name})

        # Get the labels json file
        with open(labels_path) as json_file:
            labels_artifact = json.load(json_file)

        # Log MLflow artifacts
        mlflow.log_dict(config_artifact, "config.json")
        mlflow.log_dict(labels_artifact, "labels.json")

        logger.info(f"Started MLflow run with ID ({mlflow_run.info.run_id})")

        _launch_training(config, mlflow_logging=True)
# MLflow
::: dan.mlflow
......@@ -94,6 +94,7 @@ nav:
- Utils: ref/ocr/line/utils.md
- Decoders: ref/decoder.md
- Models: ref/models.md
- MLflow: ref/mlflow.md
- Post Processing: ref/post_processing.md
- Inference: ref/predict.md
- Schedulers: ref/schedulers.md
......
arkindex-client==1.0.11
boto3==1.26.51
editdistance==0.6.1
fontTools==4.38.0
imageio==2.22.4
mlflow==2.0.1
networkx==2.8.8
numpy==1.23.5
opencv-python==4.6.0.66
......