diff --git a/Datasets/dataset_formatters/generic_dataset_formatter.py b/Datasets/dataset_formatters/generic_dataset_formatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae55af026c841af4a9ab8ee08e3437bf9d16c034
--- /dev/null
+++ b/Datasets/dataset_formatters/generic_dataset_formatter.py
@@ -0,0 +1,89 @@
+import os
+import shutil
+import tarfile
+import pickle
+import re
+from PIL import Image
+import numpy as np
+
+
+class DatasetFormatter:
+    """
+    Global pipeline/functions for dataset formatting
+    """
+
+    def __init__(self, dataset_name, level, extra_name="", set_names=["train", "valid", "test"]):
+        self.dataset_name = dataset_name
+        self.level = level
+        self.set_names = set_names
+        self.target_fold_path = os.path.join(
+            "Datasets", "formatted", "{}_{}{}".format(dataset_name, level, extra_name))
+        self.map_datasets_files = dict()
+        self.extract_with_dirname = False
+
+    def format(self):
+        self.init_format()
+        self.map_datasets_files[self.dataset_name][self.level]["format_function"]()
+        self.end_format()
+
+    def init_format(self):
+        """
+        Load and extract needed files
+        """
+        os.makedirs(self.target_fold_path, exist_ok=True)
+
+        for set_name in self.set_names:
+            os.makedirs(os.path.join(self.target_fold_path, set_name), exist_ok=True)
+
+
+class OCRDatasetFormatter(DatasetFormatter):
+    """
+    Specific pipeline/functions for OCR/HTR dataset formatting
+    """
+
+    def __init__(self, source_dataset, level, extra_name="", set_names=["train", "valid", "test"]):
+        super(OCRDatasetFormatter, self).__init__(source_dataset, level, extra_name, set_names)
+        self.charset = set()
+        self.gt = dict()
+        for set_name in set_names:
+            self.gt[set_name] = dict()
+
+    def format_text_label(self, label):
+        """
+        Remove extra space and line break characters
+        """
+        temp = re.sub("(\n)+", '\n', label)
+        return re.sub("( )+", ' ', temp).strip(" \n")
+
+    def load_resize_save(self, source_path, target_path):
+        """
+        Load image, apply resolution modification and save it
+        """
+        shutil.copyfile(source_path, target_path)
+
+    def resize(self, img, source_dpi, target_dpi):
+        """
+        Apply resolution modification to image
+        """
+        if source_dpi == target_dpi:
+            return img
+        if isinstance(img, np.ndarray):
+            h, w = img.shape[:2]
+            img = Image.fromarray(img)
+        else:
+            w, h = img.size
+        ratio = target_dpi / source_dpi
+        img = img.resize((int(w*ratio), int(h*ratio)), Image.BILINEAR)
+        return np.array(img)
+
+    def end_format(self):
+        """
+        Save label and charset files
+        """
+        with open(os.path.join(self.target_fold_path, "labels.pkl"), "wb") as f:
+            pickle.dump({
+                "ground_truth": self.gt,
+                "charset": sorted(list(self.charset)),
+            }, f)
+        with open(os.path.join(self.target_fold_path, "charset.pkl"), "wb") as f:
+            pickle.dump(sorted(list(self.charset)), f)
diff --git a/Datasets/dataset_formatters/simara_formatter.py b/Datasets/dataset_formatters/simara_formatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..c60eb30f94cf569d3b2d9d9197a16756ee994c01
--- /dev/null
+++ b/Datasets/dataset_formatters/simara_formatter.py
@@ -0,0 +1,103 @@
+from Datasets.dataset_formatters.generic_dataset_formatter import OCRDatasetFormatter
+import os
+import numpy as np
+from Datasets.dataset_formatters.utils_dataset import natural_sort
+from PIL import Image
+import xml.etree.ElementTree as ET
+import re
+from tqdm import tqdm
+
+# Layout string to token
+SEM_MATCHING_TOKENS_STR = {
+    "INTITULE": "ⓘ",
+    "DATE": "ⓓ",
+    "COTE_SERIE": "ⓢ",
+    "ANALYSE_COMPL": "ⓒ",
+    "PRECISIONS_SUR_COTE": "ⓟ",
+    "COTE_ARTICLE": "ⓐ"
+}
+
+# Layout begin-token to end-token
+SEM_MATCHING_TOKENS = {
+    "ⓘ": "Ⓘ",
+    "ⓓ": "Ⓓ",
+    "ⓢ": "Ⓢ",
+    "ⓒ": "Ⓒ",
+    "ⓟ": "Ⓟ",
+    "ⓐ": "Ⓐ"
+}
+
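+# Example of a sem_label string wrapped with these tokens (illustrative sample,
+# not taken from the dataset): each field sits between its begin token
+# (lowercase circled letter) and the matching end token (uppercase), e.g.
+#   "ⓘprocès-verbal d'arpentageⒾⓓ1732Ⓓⓢ B 183 Ⓢ"
+# SEM_MATCHING_TOKENS is used at evaluation time to parse the fields back out
+# (see basic/metric_manager.py).
+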
"COTE_ARTICLE": "â“" +} + +# Layout begin-token to end-token +SEM_MATCHING_TOKENS = { + "ⓘ": "â’¾", + "â““": "â’¹", + "â“¢": "Ⓢ", + "â“’": "â’¸", + "ⓟ": "â“…", + "â“": "â’¶" +} + +class SimaraDatasetFormatter(OCRDatasetFormatter): + def __init__(self, level, set_names=["train", "valid", "test"], dpi=150, sem_token=True): + super(SimaraDatasetFormatter, self).__init__("simara", level, "_sem" if sem_token else "", set_names) + + self.dpi = dpi + self.sem_token = sem_token + self.map_datasets_files.update({ + "simara": { + # (1,050 for train, 100 for validation and 100 for test) + "page": { + "format_function": self.format_simara_page, + }, + } + }) + self.matching_tokens_str = SEM_MATCHING_TOKENS_STR + self.matching_tokens = SEM_MATCHING_TOKENS + + def preformat_simara_page(self): + """ + Extract all information from dataset and correct some annotations + """ + dataset = { + "train": list(), + "valid": list(), + "test": list() + } + img_folder_path = os.path.join("Datasets", "raw", "simara", "images") + labels_folder_path = os.path.join("Datasets", "raw", "simara", "labels") + sem_labels_folder_path = os.path.join("Datasets", "raw", "simara", "labels_sem") + train_files = [ + os.path.join(labels_folder_path, 'train', name) + for name in os.listdir(os.path.join(sem_labels_folder_path, 'train'))] + valid_files = [ + os.path.join(labels_folder_path, 'valid', name) + for name in os.listdir(os.path.join(sem_labels_folder_path, 'valid'))] + test_files = [ + os.path.join(labels_folder_path, 'test', name) + for name in os.listdir(os.path.join(sem_labels_folder_path, 'test'))] + for set_name, files in zip(self.set_names, [train_files, valid_files, test_files]): + for i, label_file in enumerate(tqdm(files, desc='Pre-formatting '+set_name)): + with open(label_file, 'r') as f: + text = f.read() + with open(label_file.replace('labels', 'labels_sem'), 'r') as f: + sem_text = f.read() + dataset[set_name].append({ + "img_path": os.path.join( + img_folder_path, set_name, label_file.split('/')[-1].replace('txt', 'jpg')), + "label": text, + "sem_label": sem_text, + }) + print(dataset['test'], len(dataset['test'])) + return dataset + + def format_simara_page(self): + """ + Format simara page dataset + """ + dataset = self.preformat_simara_page() + for set_name in self.set_names: + fold = os.path.join(self.target_fold_path, set_name) + for sample in tqdm(dataset[set_name], desc='Formatting '+set_name): + new_name = sample['img_path'].split('/')[-1] + new_img_path = os.path.join(fold, new_name) + self.load_resize_save(sample["img_path"], new_img_path)#, 300, self.dpi) + page = { + "text": sample["label"] if not self.sem_token else sample["sem_label"], + } + self.charset = self.charset.union(set(page["text"])) + self.gt[set_name][new_name] = page + + +if __name__ == "__main__": + + SimaraDatasetFormatter("page", sem_token=True).format() + #SimaraDatasetFormatter("page", sem_token=False).format() diff --git a/Datasets/dataset_formatters/utils_dataset.py b/Datasets/dataset_formatters/utils_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..f4d980dadb926ceee5f4d13e28620e21a11a1177 --- /dev/null +++ b/Datasets/dataset_formatters/utils_dataset.py @@ -0,0 +1,35 @@ +import re +import random +import cv2 +import json +random.seed(42) + + +def natural_sort(l): + convert = lambda text: int(text) if text.isdigit() else text.lower() + alphanum_key = lambda key: [ convert(c) for c in re.split('([0-9]+)', key)] + return sorted(l, key=alphanum_key) + + +def assign_random_split(train_prob, val_prob): + 
""" + assuming train_prob + val_prob + test_prob = 1 + """ + prob = random.random() + if prob <= train_prob: + return "train" + elif prob <= train_prob + val_prob: + return "val" + else: + return "test" + +def save_text(path, text): + with open(path, 'w') as f: + f.write(text) + +def save_image(path, image): + cv2.imwrite(path, cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) + +def save_json(path, dict): + with open(path, "w") as outfile: + json.dump(dict, outfile, indent=4) \ No newline at end of file diff --git a/OCR/line_OCR/ctc/models_line_ctc.py b/OCR/line_OCR/ctc/models_line_ctc.py new file mode 100644 index 0000000000000000000000000000000000000000..c13c072ebd10e810d41b8aab1e318c2170637ea0 --- /dev/null +++ b/OCR/line_OCR/ctc/models_line_ctc.py @@ -0,0 +1,19 @@ + +from torch.nn.functional import log_softmax +from torch.nn import AdaptiveMaxPool2d, Conv1d +from torch.nn import Module + + +class Decoder(Module): + def __init__(self, params): + super(Decoder, self).__init__() + + self.vocab_size = params["vocab_size"] + + self.ada_pool = AdaptiveMaxPool2d((1, None)) + self.end_conv = Conv1d(in_channels=params["enc_size"], out_channels=self.vocab_size+1, kernel_size=1) + + def forward(self, x): + x = self.ada_pool(x).squeeze(2) + x = self.end_conv(x) + return log_softmax(x, dim=1) diff --git a/OCR/ocr_manager.py b/OCR/ocr_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..e34b36c0a0793868c5a8f590b20510db72bc0681 --- /dev/null +++ b/OCR/ocr_manager.py @@ -0,0 +1,67 @@ +from basic.generic_training_manager import GenericTrainingManager +import os +from PIL import Image +import pickle + + +class OCRManager(GenericTrainingManager): + def __init__(self, params): + super(OCRManager, self).__init__(params) + self.params["model_params"]["vocab_size"] = len(self.dataset.charset) + + def generate_syn_line_dataset(self, name): + """ + Generate synthetic line dataset from currently loaded dataset + """ + dataset_name = list(self.params['dataset_params']["datasets"].keys())[0] + path = os.path.join(os.path.dirname(self.params['dataset_params']["datasets"][dataset_name]), name) + os.makedirs(path, exist_ok=True) + charset = set() + dataset = None + gt = { + "train": dict(), + "valid": dict(), + "test": dict() + } + for set_name in ["train", "valid", "test"]: + set_path = os.path.join(path, set_name) + os.makedirs(set_path, exist_ok=True) + if set_name == "train": + dataset = self.dataset.train_dataset + elif set_name == "valid": + dataset = self.dataset.valid_datasets["{}-valid".format(dataset_name)] + elif set_name == "test": + self.dataset.generate_test_loader("{}-test".format(dataset_name), [(dataset_name, "test"), ]) + dataset = self.dataset.test_datasets["{}-test".format(dataset_name)] + + samples = list() + for sample in dataset.samples: + for line_label in sample["label"].split("\n"): + for chunk in [line_label[i:i+100] for i in range(0, len(line_label), 100)]: + charset = charset.union(set(chunk)) + if len(chunk) > 0: + samples.append({ + "path": sample["path"], + "label": chunk, + "nb_cols": 1, + }) + + for i, sample in enumerate(samples): + ext = sample['path'].split(".")[-1] + img_name = "{}_{}.{}".format(set_name, i, ext) + img_path = os.path.join(set_path, img_name) + + img = dataset.generate_typed_text_line_image(sample["label"]) + Image.fromarray(img).save(img_path) + gt[set_name][img_name] = { + "text": sample["label"], + "nb_cols": sample["nb_cols"] if "nb_cols" in sample else 1 + } + if "line_label" in sample: + gt[set_name][img_name]["lines"] = 
sample["line_label"] + + with open(os.path.join(path, "labels.pkl"), "wb") as f: + pickle.dump({ + "ground_truth": gt, + "charset": sorted(list(charset)), + }, f) \ No newline at end of file diff --git a/basic/generic_training_manager.py b/basic/generic_training_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..a8f41b28274678b23083123e33d2343bb6b942b9 --- /dev/null +++ b/basic/generic_training_manager.py @@ -0,0 +1,706 @@ +import torch +import os +import sys +import copy +import json +import torch.distributed as dist +import torch.multiprocessing as mp +import random +import numpy as np +from torch.utils.tensorboard import SummaryWriter +from torch.nn.init import kaiming_uniform_ +from tqdm import tqdm +from time import time +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.cuda.amp import GradScaler +from basic.metric_manager import MetricManager +from basic.scheduler import DropoutScheduler +from datetime import date + + +class GenericTrainingManager: + + def __init__(self, params): + self.type = None + self.is_master = False + self.params = params + self.dropout_scheduler = None + self.models = {} + self.begin_time = None + self.dataset = None + self.dataset_name = list(self.params["dataset_params"]["datasets"].values())[0] + self.paths = None + self.latest_step = 0 + self.latest_epoch = -1 + self.latest_batch = 0 + self.total_batch = 0 + self.grad_acc_step = 0 + self.latest_train_metrics = dict() + self.latest_valid_metrics = dict() + self.curriculum_info = dict() + self.curriculum_info["latest_valid_metrics"] = dict() + self.phase = None + self.max_mem_usage_by_epoch = list() + self.losses = list() + self.lr_values = list() + + self.scaler = None + + self.optimizers = dict() + self.optimizers_named_params_by_group = dict() + self.lr_schedulers = dict() + self.best = None + self.writer = None + self.metric_manager = dict() + + self.init_hardware_config() + self.init_paths() + self.load_dataset() + self.params["model_params"]["use_amp"] = self.params["training_params"]["use_amp"] + + def init_paths(self): + """ + Create output folders for results and checkpoints + """ + output_path = os.path.join("outputs", self.params["training_params"]["output_folder"]) + os.makedirs(output_path, exist_ok=True) + checkpoints_path = os.path.join(output_path, "checkpoints") + os.makedirs(checkpoints_path, exist_ok=True) + results_path = os.path.join(output_path, "results") + os.makedirs(results_path, exist_ok=True) + + self.paths = { + "results": results_path, + "checkpoints": checkpoints_path, + "output_folder": output_path + } + + def load_dataset(self): + """ + Load datasets, data samplers and data loaders + """ + self.params["dataset_params"]["use_ddp"] = self.params["training_params"]["use_ddp"] + self.params["dataset_params"]["batch_size"] = self.params["training_params"]["batch_size"] + if "valid_batch_size" in self.params["training_params"]: + self.params["dataset_params"]["valid_batch_size"] = self.params["training_params"]["valid_batch_size"] + if "test_batch_size" in self.params["training_params"]: + self.params["dataset_params"]["test_batch_size"] = self.params["training_params"]["test_batch_size"] + self.params["dataset_params"]["num_gpu"] = self.params["training_params"]["nb_gpu"] + self.params["dataset_params"]["worker_per_gpu"] = 4 if "worker_per_gpu" not in self.params["dataset_params"] else self.params["dataset_params"]["worker_per_gpu"] + self.dataset = self.params["dataset_params"]["dataset_manager"](self.params["dataset_params"]) 
+ self.dataset.load_datasets() + self.dataset.load_ddp_samplers() + self.dataset.load_dataloaders() + + def init_hardware_config(self): + # Debug mode + if self.params["training_params"]["force_cpu"]: + self.params["training_params"]["use_ddp"] = False + self.params["training_params"]["use_amp"] = False + # Manage Distributed Data Parallel & GPU usage + self.manual_seed = 1111 if "manual_seed" not in self.params["training_params"].keys() else \ + self.params["training_params"]["manual_seed"] + self.ddp_config = { + "master": self.params["training_params"]["use_ddp"] and self.params["training_params"]["ddp_rank"] == 0, + "address": "localhost" if "ddp_addr" not in self.params["training_params"].keys() else self.params["training_params"]["ddp_addr"], + "port": "11111" if "ddp_port" not in self.params["training_params"].keys() else self.params["training_params"]["ddp_port"], + "backend": "nccl" if "ddp_backend" not in self.params["training_params"].keys() else self.params["training_params"]["ddp_backend"], + "rank": self.params["training_params"]["ddp_rank"], + } + self.is_master = self.ddp_config["master"] or not self.params["training_params"]["use_ddp"] + if self.params["training_params"]["force_cpu"]: + self.device = "cpu" + else: + if self.params["training_params"]["use_ddp"]: + self.device = torch.device(self.ddp_config["rank"]) + self.params["dataset_params"]["ddp_rank"] = self.ddp_config["rank"] + self.launch_ddp() + else: + self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + self.params["model_params"]["device"] = self.device.type + # Print GPU info + # global + if (self.params["training_params"]["use_ddp"] and self.ddp_config["master"]) or not self.params["training_params"]["use_ddp"]: + print("##################") + print("Available GPUS: {}".format(self.params["training_params"]["nb_gpu"])) + for i in range(self.params["training_params"]["nb_gpu"]): + print("Rank {}: {} {}".format(i, torch.cuda.get_device_name(i), torch.cuda.get_device_properties(i))) + print("##################") + # local + print("Local GPU:") + if self.device != "cpu": + print("Rank {}: {} {}".format(self.params["training_params"]["ddp_rank"], torch.cuda.get_device_name(), torch.cuda.get_device_properties(self.device))) + else: + print("WORKING ON CPU !\n") + print("##################") + + def load_model(self, reset_optimizer=False, strict=True): + """ + Load model weights from scratch or from checkpoints + """ + # Instantiate Model + for model_name in self.params["model_params"]["models"].keys(): + self.models[model_name] = self.params["model_params"]["models"][model_name](self.params["model_params"]) + self.models[model_name].to(self.device) # To GPU or CPU + # make the model compatible with Distributed Data Parallel if used + if self.params["training_params"]["use_ddp"]: + self.models[model_name] = DDP(self.models[model_name], [self.ddp_config["rank"]]) + + # Handle curriculum dropout + if "dropout_scheduler" in self.params["model_params"]: + func = self.params["model_params"]["dropout_scheduler"]["function"] + T = self.params["model_params"]["dropout_scheduler"]["T"] + self.dropout_scheduler = DropoutScheduler(self.models, func, T) + + self.scaler = GradScaler(enabled=self.params["training_params"]["use_amp"]) + + # Check if checkpoint exists + checkpoint = self.get_checkpoint() + if checkpoint is not None: + self.load_existing_model(checkpoint, strict=strict) + else: + self.init_new_model() + + self.load_optimizers(checkpoint, reset_optimizer=reset_optimizer) + + if 
self.is_master:
+            print("LOADED EPOCH: {}\n".format(self.latest_epoch), flush=True)
+
+    def get_checkpoint(self):
+        """
+        Return the checkpoint if one exists, None otherwise
+        """
+        if self.params["training_params"]["load_epoch"] in ("best", "last"):
+            for filename in os.listdir(self.paths["checkpoints"]):
+                if self.params["training_params"]["load_epoch"] in filename:
+                    return torch.load(os.path.join(self.paths["checkpoints"], filename))
+        return None
+
+    def load_existing_model(self, checkpoint, strict=True):
+        """
+        Load information and weights from previous training
+        """
+        self.load_save_info(checkpoint)
+        self.latest_epoch = checkpoint["epoch"]
+        if "step" in checkpoint:
+            self.latest_step = checkpoint["step"]
+        self.best = checkpoint["best"]
+        if "scaler_state_dict" in checkpoint:
+            self.scaler.load_state_dict(checkpoint["scaler_state_dict"])
+        # Load model weights from past training
+        for model_name in self.models.keys():
+            self.models[model_name].load_state_dict(checkpoint["{}_state_dict".format(model_name)], strict=strict)
+
+    def init_new_model(self):
+        """
+        Initialize model
+        """
+        # Specific weights initialization if it exists
+        for model_name in self.models.keys():
+            try:
+                self.models[model_name].init_weights()
+            except AttributeError:
+                # the model does not define a specific initialization
+                pass
+
+        # Handle transfer learning instructions
+        if self.params["model_params"]["transfer_learning"]:
+            # Iterate over models
+            for model_name in self.params["model_params"]["transfer_learning"].keys():
+                state_dict_name, path, learnable, strict = self.params["model_params"]["transfer_learning"][model_name]
+                # Load pretrained weights file
+                checkpoint = torch.load(path)
+                try:
+                    # Load pretrained weights for model
+                    self.models[model_name].load_state_dict(checkpoint["{}_state_dict".format(state_dict_name)], strict=strict)
+                    print("transferred weights for {}".format(state_dict_name), flush=True)
+                except RuntimeError as e:
+                    print(e, flush=True)
+                    # on error, try to load each part of the model (useful if only a few layers differ)
+                    for key in checkpoint["{}_state_dict".format(state_dict_name)].keys():
+                        try:
+                            # for pre-training of the decision layer
+                            if "end_conv" in key and "transfered_charset" in self.params["model_params"]:
+                                self.adapt_decision_layer_to_old_charset(model_name, key, checkpoint, state_dict_name)
+                            else:
+                                self.models[model_name].load_state_dict(
+                                    {key: checkpoint["{}_state_dict".format(state_dict_name)][key]}, strict=False)
+                        except RuntimeError as e:
+                            # exception raised when adding the linebreak token from pre-training
+                            print(e, flush=True)
+                # Set parameters as non-trainable if requested
+                if not learnable:
+                    self.set_model_learnable(self.models[model_name], False)
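+
+    # A "transfer_learning" entry in model_params unpacks as
+    # (state_dict_name, path, learnable, strict), as used above. A minimal
+    # configuration sketch (checkpoint path and model names are placeholders,
+    # not values from this repository):
+    #   "transfer_learning": {
+    #       "encoder": ["encoder", "outputs/pretraining/checkpoints/best_0.pt", True, True],
+    #       "decoder": ["decoder", "outputs/pretraining/checkpoints/best_0.pt", True, False],
+    #   }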
self.params["model_params"] and self.params["model_params"]["transfered_charset_last_is_ctc_blank"]: + new_weights[-1] = weights[-1] + pretrained_chars.append("<blank>") + checkpoint["{}_state_dict".format(state_dict_name)][key] = new_weights + self.models[model_name].load_state_dict({key: checkpoint["{}_state_dict".format(state_dict_name)][key]}, strict=False) + print("Pretrained chars for {} ({}): {}".format(key, len(pretrained_chars), pretrained_chars)) + + def load_optimizers(self, checkpoint, reset_optimizer=False): + """ + Load the optimizer of each model + """ + for model_name in self.models.keys(): + new_params = dict() + if checkpoint and "optimizer_named_params_{}".format(model_name) in checkpoint: + self.optimizers_named_params_by_group[model_name] = checkpoint["optimizer_named_params_{}".format(model_name)] + # for progressively growing models + for name, param in self.models[model_name].named_parameters(): + existing = False + for gr in self.optimizers_named_params_by_group[model_name]: + if name in gr: + gr[name] = param + existing = True + break + if not existing: + new_params.update({name: param}) + else: + self.optimizers_named_params_by_group[model_name] = [dict(), ] + self.optimizers_named_params_by_group[model_name][0].update(self.models[model_name].named_parameters()) + + # Instantiate optimizer + self.reset_optimizer(model_name) + + # Handle learning rate schedulers + if "lr_schedulers" in self.params["training_params"] and self.params["training_params"]["lr_schedulers"]: + key = "all" if "all" in self.params["training_params"]["lr_schedulers"] else model_name + if key in self.params["training_params"]["lr_schedulers"]: + self.lr_schedulers[model_name] = self.params["training_params"]["lr_schedulers"][key]["class"]\ + (self.optimizers[model_name], **self.params["training_params"]["lr_schedulers"][key]["args"]) + + # Load optimizer state from past training + if checkpoint and not reset_optimizer: + self.optimizers[model_name].load_state_dict(checkpoint["optimizer_{}_state_dict".format(model_name)]) + # Load optimizer scheduler config from past training if used + if "lr_schedulers" in self.params["training_params"] and self.params["training_params"]["lr_schedulers"] \ + and "lr_scheduler_{}_state_dict".format(model_name) in checkpoint.keys(): + self.lr_schedulers[model_name].load_state_dict(checkpoint["lr_scheduler_{}_state_dict".format(model_name)]) + + # for progressively growing models, keeping learning rate + if checkpoint and new_params: + self.optimizers_named_params_by_group[model_name].append(new_params) + self.optimizers[model_name].add_param_group({"params": list(new_params.values())}) + + @staticmethod + def set_model_learnable(model, learnable=True): + for p in list(model.parameters()): + p.requires_grad = learnable + + def save_model(self, epoch, name, keep_weights=False): + """ + Save model weights and training info for curriculum learning or learning rate for instance + """ + if not self.is_master: + return + to_del = [] + for filename in os.listdir(self.paths["checkpoints"]): + if name in filename: + to_del.append(os.path.join(self.paths["checkpoints"], filename)) + path = os.path.join(self.paths["checkpoints"], "{}_{}.pt".format(name, epoch)) + content = { + 'optimizers_named_params': self.optimizers_named_params_by_group, + 'epoch': epoch, + 'step': self.latest_step, + "scaler_state_dict": self.scaler.state_dict(), + 'best': self.best, + "charset": self.dataset.charset + } + for model_name in self.optimizers: + 
content['optimizer_{}_state_dict'.format(model_name)] = self.optimizers[model_name].state_dict() + for model_name in self.lr_schedulers: + content["lr_scheduler_{}_state_dict".format(model_name)] = self.lr_schedulers[model_name].state_dict() + content = self.add_save_info(content) + for model_name in self.models.keys(): + content["{}_state_dict".format(model_name)] = self.models[model_name].state_dict() + torch.save(content, path) + if not keep_weights: + for path_to_del in to_del: + if path_to_del != path: + os.remove(path_to_del) + + def reset_optimizers(self): + """ + Reset learning rate of all optimizers + """ + for model_name in self.models.keys(): + self.reset_optimizer(model_name) + + def reset_optimizer(self, model_name): + """ + Reset optimizer learning rate for given model + """ + params = list(self.optimizers_named_params_by_group[model_name][0].values()) + key = "all" if "all" in self.params["training_params"]["optimizers"] else model_name + self.optimizers[model_name] = self.params["training_params"]["optimizers"][key]["class"](params, **self.params["training_params"]["optimizers"][key]["args"]) + for i in range(1, len(self.optimizers_named_params_by_group[model_name])): + self.optimizers[model_name].add_param_group({"params": list(self.optimizers_named_params_by_group[model_name][i].values())}) + + def save_params(self): + """ + Output text file containing a summary of all hyperparameters chosen for the training + """ + def compute_nb_params(module): + return sum([np.prod(p.size()) for p in list(module.parameters())]) + + def class_to_str_dict(my_dict): + for key in my_dict.keys(): + if callable(my_dict[key]): + my_dict[key] = my_dict[key].__name__ + elif isinstance(my_dict[key], np.ndarray): + my_dict[key] = my_dict[key].tolist() + elif isinstance(my_dict[key], dict): + my_dict[key] = class_to_str_dict(my_dict[key]) + return my_dict + + path = os.path.join(self.paths["results"], "params") + if os.path.isfile(path): + return + params = copy.deepcopy(self.params) + params = class_to_str_dict(params) + params["date"] = date.today().strftime("%d/%m/%Y") + total_params = 0 + for model_name in self.models.keys(): + current_params = compute_nb_params(self.models[model_name]) + params["model_params"]["models"][model_name] = [params["model_params"]["models"][model_name], "{:,}".format(current_params)] + total_params += current_params + params["model_params"]["total_params"] = "{:,}".format(total_params) + + params["hardware"] = dict() + if self.device != "cpu": + for i in range(self.params["training_params"]["nb_gpu"]): + params["hardware"][str(i)] = "{} {}".format(torch.cuda.get_device_name(i), torch.cuda.get_device_properties(i)) + else: + params["hardware"]["0"] = "CPU" + params["software"] = { + "python_version": sys.version, + "pytorch_version": torch.__version__, + "cuda_version": torch.version.cuda, + "cudnn_version": torch.backends.cudnn.version(), + } + with open(path, 'w') as f: + json.dump(params, f, indent=4) + + def backward_loss(self, loss, retain_graph=False): + self.scaler.scale(loss).backward(retain_graph=retain_graph) + + def step_optimizers(self, increment_step=True, names=None): + for model_name in self.optimizers: + if names and model_name not in names: + continue + if "gradient_clipping" in self.params["training_params"] and model_name in self.params["training_params"]["gradient_clipping"]["models"]: + self.scaler.unscale_(self.optimizers[model_name]) + torch.nn.utils.clip_grad_norm_(self.models[model_name].parameters(), 
self.params["training_params"]["gradient_clipping"]["max"]) + self.scaler.step(self.optimizers[model_name]) + self.scaler.update() + self.latest_step += 1 + + def zero_optimizers(self, set_to_none=True): + for model_name in self.optimizers: + self.zero_optimizer(model_name, set_to_none) + + def zero_optimizer(self, model_name, set_to_none=True): + self.optimizers[model_name].zero_grad(set_to_none=set_to_none) + + def train(self): + """ + Main training loop + """ + # init tensorboard file and output param summary file + if self.is_master: + self.writer = SummaryWriter(self.paths["results"]) + self.save_params() + # init variables + self.begin_time = time() + focus_metric_name = self.params["training_params"]["focus_metric"] + nb_epochs = self.params["training_params"]["max_nb_epochs"] + interval_save_weights = self.params["training_params"]["interval_save_weights"] + metric_names = self.params["training_params"]["train_metrics"] + + display_values = None + # init curriculum learning + if "curriculum_learning" in self.params["training_params"].keys() and self.params["training_params"]["curriculum_learning"]: + self.init_curriculum() + # perform epochs + for num_epoch in range(self.latest_epoch+1, nb_epochs): + self.dataset.train_dataset.training_info = { + "epoch": self.latest_epoch, + "step": self.latest_step + } + self.phase = "train" + # Check maximum training time stop condition + if self.params["training_params"]["max_training_time"] and time() - self.begin_time > self.params["training_params"]["max_training_time"]: + break + # set models trainable + for model_name in self.models.keys(): + self.models[model_name].train() + self.latest_epoch = num_epoch + if self.dataset.train_dataset.curriculum_config: + self.dataset.train_dataset.curriculum_config["epoch"] = self.latest_epoch + # init epoch metrics values + self.metric_manager["train"] = MetricManager(metric_names=metric_names, dataset_name=self.dataset_name) + + with tqdm(total=len(self.dataset.train_loader.dataset)) as pbar: + pbar.set_description("EPOCH {}/{}".format(num_epoch, nb_epochs)) + # iterates over mini-batch data + for ind_batch, batch_data in enumerate(self.dataset.train_loader): + self.latest_batch = ind_batch + 1 + self.total_batch += 1 + # train on batch data and compute metrics + batch_values = self.train_batch(batch_data, metric_names) + batch_metrics = self.metric_manager["train"].compute_metrics(batch_values, metric_names) + batch_metrics["names"] = batch_data["names"] + batch_metrics["ids"] = batch_data["ids"] + # Merge metrics if Distributed Data Parallel is used + if self.params["training_params"]["use_ddp"]: + batch_metrics = self.merge_ddp_metrics(batch_metrics) + # Update learning rate via scheduler if one is used + if self.params["training_params"]["lr_schedulers"]: + for model_name in self.models: + key = "all" if "all" in self.params["training_params"]["lr_schedulers"] else model_name + if model_name in self.lr_schedulers and ind_batch % self.params["training_params"]["lr_schedulers"][key]["step_interval"] == 0: + self.lr_schedulers[model_name].step(len(batch_metrics["names"])) + if "lr" in metric_names: + self.writer.add_scalar("lr_{}".format(model_name), self.lr_schedulers[model_name].lr, self.lr_schedulers[model_name].step_num) + # Update dropout scheduler if used + if self.dropout_scheduler: + self.dropout_scheduler.step(len(batch_metrics["names"])) + self.dropout_scheduler.update_dropout_rate() + + # Add batch metrics values to epoch metrics values + 
self.metric_manager["train"].update_metrics(batch_metrics) + display_values = self.metric_manager["train"].get_display_values() + pbar.set_postfix(values=str(display_values)) + pbar.update(len(batch_data["names"])) + + # log metrics in tensorboard file + if self.is_master: + for key in display_values.keys(): + self.writer.add_scalar('{}_{}'.format(self.params["dataset_params"]["train"]["name"], key), display_values[key], num_epoch) + self.latest_train_metrics = display_values + + # evaluate and compute metrics for valid sets + if self.params["training_params"]["eval_on_valid"] and num_epoch % self.params["training_params"]["eval_on_valid_interval"] == 0: + for valid_set_name in self.dataset.valid_loaders.keys(): + # evaluate set and compute metrics + eval_values = self.evaluate(valid_set_name) + self.latest_valid_metrics = eval_values + # log valid metrics in tensorboard file + if self.is_master: + for key in eval_values.keys(): + self.writer.add_scalar('{}_{}'.format(valid_set_name, key), eval_values[key], num_epoch) + if valid_set_name == self.params["training_params"]["set_name_focus_metric"] and (self.best is None or \ + (eval_values[focus_metric_name] <= self.best and self.params["training_params"]["expected_metric_value"] == "low") or\ + (eval_values[focus_metric_name] >= self.best and self.params["training_params"]["expected_metric_value"] == "high")): + self.save_model(epoch=num_epoch, name="best") + self.best = eval_values[focus_metric_name] + + # Handle curriculum learning update + if self.dataset.train_dataset.curriculum_config: + self.check_and_update_curriculum() + + if "curriculum_model" in self.params["model_params"] and self.params["model_params"]["curriculum_model"]: + self.update_curriculum_model() + + # save model weights + if self.is_master: + self.save_model(epoch=num_epoch, name="last") + if interval_save_weights and num_epoch % interval_save_weights == 0: + self.save_model(epoch=num_epoch, name="weigths", keep_weights=True) + self.writer.flush() + + def evaluate(self, set_name, **kwargs): + """ + Main loop for validation + """ + self.phase = "eval" + loader = self.dataset.valid_loaders[set_name] + # Set models in eval mode + for model_name in self.models.keys(): + self.models[model_name].eval() + metric_names = self.params["training_params"]["eval_metrics"] + display_values = None + + # initialize epoch metrics + self.metric_manager[set_name] = MetricManager(metric_names, dataset_name=self.dataset_name) + with tqdm(total=len(loader.dataset)) as pbar: + pbar.set_description("Evaluation E{}".format(self.latest_epoch)) + with torch.no_grad(): + # iterate over batch data + for ind_batch, batch_data in enumerate(loader): + self.latest_batch = ind_batch + 1 + # eval batch data and compute metrics + batch_values = self.evaluate_batch(batch_data, metric_names) + batch_metrics = self.metric_manager[set_name].compute_metrics(batch_values, metric_names) + batch_metrics["names"] = batch_data["names"] + batch_metrics["ids"] = batch_data["ids"] + # merge metrics values if Distributed Data Parallel is used + if self.params["training_params"]["use_ddp"]: + batch_metrics = self.merge_ddp_metrics(batch_metrics) + + # add batch metrics to epoch metrics + self.metric_manager[set_name].update_metrics(batch_metrics) + display_values = self.metric_manager[set_name].get_display_values() + + pbar.set_postfix(values=str(display_values)) + pbar.update(len(batch_data["names"])) + if "cer_by_nb_cols" in metric_names: + self.log_cer_by_nb_cols(set_name) + return display_values + + def 
predict(self, custom_name, sets_list, metric_names, output=False): + """ + Main loop for evaluation + """ + self.phase = "predict" + metric_names = metric_names.copy() + self.dataset.generate_test_loader(custom_name, sets_list) + loader = self.dataset.test_loaders[custom_name] + # Set models in eval mode + for model_name in self.models.keys(): + self.models[model_name].eval() + + # initialize epoch metrics + self.metric_manager[custom_name] = MetricManager(metric_names, self.dataset_name) + + with tqdm(total=len(loader.dataset)) as pbar: + pbar.set_description("Prediction") + with torch.no_grad(): + for ind_batch, batch_data in enumerate(loader): + # iterates over batch data + self.latest_batch = ind_batch + 1 + # eval batch data and compute metrics + batch_values = self.evaluate_batch(batch_data, metric_names) + batch_metrics = self.metric_manager[custom_name].compute_metrics(batch_values, metric_names) + batch_metrics["names"] = batch_data["names"] + batch_metrics["ids"] = batch_data["ids"] + # merge batch metrics if Distributed Data Parallel is used + if self.params["training_params"]["use_ddp"]: + batch_metrics = self.merge_ddp_metrics(batch_metrics) + + # add batch metrics to epoch metrics + self.metric_manager[custom_name].update_metrics(batch_metrics) + display_values = self.metric_manager[custom_name].get_display_values() + + pbar.set_postfix(values=str(display_values)) + pbar.update(len(batch_data["names"])) + + self.dataset.remove_test_dataset(custom_name) + # output metrics values if requested + if output: + if "pred" in metric_names: + self.output_pred(custom_name) + metrics = self.metric_manager[custom_name].get_display_values(output=True) + path = os.path.join(self.paths["results"], "predict_{}_{}.txt".format(custom_name, self.latest_epoch)) + with open(path, "w") as f: + for metric_name in metrics.keys(): + f.write("{}: {}\n".format(metric_name, metrics[metric_name])) + + def output_pred(self, name): + path = os.path.join(self.paths["results"], "pred_{}_{}.txt".format(name, self.latest_epoch)) + pred = "\n".join(self.metric_manager[name].get("pred")) + with open(path, "w") as f: + f.write(pred) + + def launch_ddp(self): + """ + Initialize Distributed Data Parallel system + """ + mp.set_start_method('fork', force=True) + os.environ['MASTER_ADDR'] = self.ddp_config["address"] + os.environ['MASTER_PORT'] = str(self.ddp_config["port"]) + dist.init_process_group(self.ddp_config["backend"], rank=self.ddp_config["rank"], world_size=self.params["training_params"]["nb_gpu"]) + torch.cuda.set_device(self.ddp_config["rank"]) + random.seed(self.manual_seed) + np.random.seed(self.manual_seed) + torch.manual_seed(self.manual_seed) + torch.cuda.manual_seed(self.manual_seed) + + def merge_ddp_metrics(self, metrics): + """ + Merge metrics when Distributed Data Parallel is used + """ + for metric_name in metrics.keys(): + if metric_name in ["edit_words", "nb_words", "edit_chars", "nb_chars", "edit_chars_force_len", + "edit_chars_curr", "nb_chars_curr", "ids"]: + metrics[metric_name] = self.cat_ddp_metric(metrics[metric_name]) + elif metric_name in ["nb_samples", "loss", "loss_ce", "loss_ctc", "loss_ce_end"]: + metrics[metric_name] = self.sum_ddp_metric(metrics[metric_name], average=False) + return metrics + + def sum_ddp_metric(self, metric, average=False): + """ + Sum metrics for Distributed Data Parallel + """ + sum = torch.tensor(metric[0]).to(self.device) + dist.all_reduce(sum, op=dist.ReduceOp.SUM) + if average: + sum.true_divide(dist.get_world_size()) + return [sum.item(), ] + + def 
cat_ddp_metric(self, metric): + """ + Concatenate metrics for Distributed Data Parallel + """ + tensor = torch.tensor(metric).unsqueeze(0).to(self.device) + res = [torch.zeros(tensor.size()).long().to(self.device) for _ in range(dist.get_world_size())] + dist.all_gather(res, tensor) + return list(torch.cat(res, dim=0).flatten().cpu().numpy()) + + @staticmethod + def cleanup(): + dist.destroy_process_group() + + def train_batch(self, batch_data, metric_names): + raise NotImplementedError + + def evaluate_batch(self, batch_data, metric_names): + raise NotImplementedError + + def init_curriculum(self): + raise NotImplementedError + + def update_curriculum(self): + raise NotImplementedError + + def add_checkpoint_info(self, load_mode="last", **kwargs): + for filename in os.listdir(self.paths["checkpoints"]): + if load_mode in filename: + checkpoint_path = os.path.join(self.paths["checkpoints"], filename) + checkpoint = torch.load(checkpoint_path) + for key in kwargs.keys(): + checkpoint[key] = kwargs[key] + torch.save(checkpoint, checkpoint_path) + return + self.save_model(self.latest_epoch, "last") + + def load_save_info(self, info_dict): + """ + Load curriculum info from saved model info + """ + if "curriculum_config" in info_dict.keys(): + self.dataset.train_dataset.curriculum_config = info_dict["curriculum_config"] + + def add_save_info(self, info_dict): + """ + Add curriculum info to model info to be saved + """ + info_dict["curriculum_config"] = self.dataset.train_dataset.curriculum_config + return info_dict \ No newline at end of file diff --git a/basic/metric_manager.py b/basic/metric_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..6c8576863a2d26a85c2d5dbc4321323dedf870a0 --- /dev/null +++ b/basic/metric_manager.py @@ -0,0 +1,538 @@ + +from Datasets.dataset_formatters.rimes_formatter import SEM_MATCHING_TOKENS as RIMES_MATCHING_TOKENS +from Datasets.dataset_formatters.read2016_formatter import SEM_MATCHING_TOKENS as READ_MATCHING_TOKENS +from Datasets.dataset_formatters.simara_formatter import SEM_MATCHING_TOKENS as SIMARA_MATCHING_TOKENS +import re +import networkx as nx +import editdistance +import numpy as np +from dan.post_processing import PostProcessingModuleREAD, PostProcessingModuleRIMES, PostProcessingModuleSIMARA + + +class MetricManager: + + def __init__(self, metric_names, dataset_name): + self.dataset_name = dataset_name + if "READ" in dataset_name and "page" in dataset_name: + self.post_processing_module = PostProcessingModuleREAD + self.matching_tokens = READ_MATCHING_TOKENS + self.edit_and_num_edge_nodes = edit_and_num_items_for_ged_from_str_read + elif "RIMES" in dataset_name and "page" in dataset_name: + self.post_processing_module = PostProcessingModuleRIMES + self.matching_tokens = RIMES_MATCHING_TOKENS + self.edit_and_num_edge_nodes = edit_and_num_items_for_ged_from_str_rimes + elif "simara" in dataset_name and "page" in dataset_name: + self.post_processing_module = PostProcessingModuleSIMARA + self.matching_tokens = SIMARA_MATCHING_TOKENS + self.edit_and_num_edge_nodes = edit_and_num_items_for_ged_from_str_simara + else: + self.matching_tokens = dict() + + self.layout_tokens = "".join(list(self.matching_tokens.keys()) + list(self.matching_tokens.values())) + if len(self.layout_tokens) == 0: + self.layout_tokens = None + self.metric_names = metric_names + self.epoch_metrics = None + + self.linked_metrics = { + "cer": ["edit_chars", "nb_chars"], + "wer": ["edit_words", "nb_words"], + "loer": ["edit_graph", "nb_nodes_and_edges", 
"nb_pp_op_layout", "nb_gt_layout_token"], + "precision": ["precision", "weights"], + "map_cer_per_class": ["map_cer", ], + "layout_precision_per_class_per_threshold": ["map_cer", ], + } + + self.init_metrics() + + def init_metrics(self): + """ + Initialization of the metrics specified in metrics_name + """ + self.epoch_metrics = { + "nb_samples": list(), + "names": list(), + "ids": list(), + } + + for metric_name in self.metric_names: + if metric_name in self.linked_metrics: + for linked_metric_name in self.linked_metrics[metric_name]: + if linked_metric_name not in self.epoch_metrics.keys(): + self.epoch_metrics[linked_metric_name] = list() + else: + self.epoch_metrics[metric_name] = list() + + def update_metrics(self, batch_metrics): + """ + Add batch metrics to the metrics + """ + for key in batch_metrics.keys(): + if key in self.epoch_metrics: + self.epoch_metrics[key] += batch_metrics[key] + + def get_display_values(self, output=False): + """ + format metrics values for shell display purposes + """ + metric_names = self.metric_names.copy() + if output: + metric_names.extend(["nb_samples"]) + display_values = dict() + for metric_name in metric_names: + value = None + if output: + if metric_name in ["nb_samples", "weights"]: + value = np.sum(self.epoch_metrics[metric_name]) + elif metric_name in ["time", ]: + total_time = np.sum(self.epoch_metrics[metric_name]) + sample_time = total_time / np.sum(self.epoch_metrics["nb_samples"]) + display_values["sample_time"] = round(sample_time, 4) + value = total_time + elif metric_name == "loer": + display_values["pper"] = round(np.sum(self.epoch_metrics["nb_pp_op_layout"]) / np.sum(self.epoch_metrics["nb_gt_layout_token"]), 4) + elif metric_name == "map_cer_per_class": + value = compute_global_mAP_per_class(self.epoch_metrics["map_cer"]) + for key in value.keys(): + display_values["map_cer_" + key] = round(value[key], 4) + continue + elif metric_name == "layout_precision_per_class_per_threshold": + value = compute_global_precision_per_class_per_threshold(self.epoch_metrics["map_cer"]) + for key_class in value.keys(): + for threshold in value[key_class].keys(): + display_values["map_cer_{}_{}".format(key_class, threshold)] = round( + value[key_class][threshold], 4) + continue + if metric_name == "cer": + value = np.sum(self.epoch_metrics["edit_chars"]) / np.sum(self.epoch_metrics["nb_chars"]) + if output: + display_values["nb_chars"] = np.sum(self.epoch_metrics["nb_chars"]) + elif metric_name == "wer": + value = np.sum(self.epoch_metrics["edit_words"]) / np.sum(self.epoch_metrics["nb_words"]) + if output: + display_values["nb_words"] = np.sum(self.epoch_metrics["nb_words"]) + elif metric_name in ["loss", "loss_ctc", "loss_ce", "syn_max_lines"]: + value = np.average(self.epoch_metrics[metric_name], weights=np.array(self.epoch_metrics["nb_samples"])) + elif metric_name == "map_cer": + value = compute_global_mAP(self.epoch_metrics[metric_name]) + elif metric_name == "loer": + value = np.sum(self.epoch_metrics["edit_graph"]) / np.sum(self.epoch_metrics["nb_nodes_and_edges"]) + elif value is None: + continue + + display_values[metric_name] = round(value, 4) + return display_values + + def compute_metrics(self, values, metric_names): + metrics = { + "nb_samples": [values["nb_samples"], ], + } + for v in ["weights", "time"]: + if v in values: + metrics[v] = [values[v]] + for metric_name in metric_names: + if metric_name == "cer": + metrics["edit_chars"] = [edit_cer_from_string(u, v, self.layout_tokens) for u, v in zip(values["str_y"], values["str_x"])] + 
metrics["nb_chars"] = [nb_chars_cer_from_string(gt, self.layout_tokens) for gt in values["str_y"]] + elif metric_name == "wer": + split_gt = [format_string_for_wer(gt, self.layout_tokens) for gt in values["str_y"]] + split_pred = [format_string_for_wer(pred, self.layout_tokens) for pred in values["str_x"]] + metrics["edit_words"] = [edit_wer_from_formatted_split_text(gt, pred) for (gt, pred) in zip(split_gt, split_pred)] + metrics["nb_words"] = [len(gt) for gt in split_gt] + elif metric_name in ["loss_ctc", "loss_ce", "loss", "syn_max_lines", ]: + metrics[metric_name] = [values[metric_name], ] + elif metric_name == "map_cer": + pp_pred = list() + pp_score = list() + for pred, score in zip(values["str_x"], values["confidence_score"]): + pred_score = self.post_processing_module().post_process(pred, score) + pp_pred.append(pred_score[0]) + pp_score.append(pred_score[1]) + metrics[metric_name] = [compute_layout_mAP_per_class(y, x, conf, self.matching_tokens) for x, conf, y in zip(pp_pred, pp_score, values["str_y"])] + elif metric_name == "loer": + pp_pred = list() + metrics["nb_pp_op_layout"] = list() + for pred in values["str_x"]: + pp_module = self.post_processing_module() + pp_pred.append(pp_module.post_process(pred)) + metrics["nb_pp_op_layout"].append(pp_module.num_op) + metrics["nb_gt_layout_token"] = [len(keep_only_tokens(str_x, self.layout_tokens)) for str_x in values["str_x"]] + edit_and_num_items = [self.edit_and_num_edge_nodes(y, x) for x, y in zip(pp_pred, values["str_y"])] + metrics["edit_graph"], metrics["nb_nodes_and_edges"] = [ei[0] for ei in edit_and_num_items], [ei[1] for ei in edit_and_num_items] + return metrics + + def get(self, name): + return self.epoch_metrics[name] + + +def keep_only_tokens(str, tokens): + """ + Remove all but layout tokens from string + """ + return re.sub('([^' + tokens + '])', '', str) + + +def keep_all_but_tokens(str, tokens): + """ + Remove all layout tokens from string + """ + return re.sub('([' + tokens + '])', '', str) + + +def edit_cer_from_string(gt, pred, layout_tokens=None): + """ + Format and compute edit distance between two strings at character level + """ + gt = format_string_for_cer(gt, layout_tokens) + pred = format_string_for_cer(pred, layout_tokens) + return editdistance.eval(gt, pred) + + +def nb_chars_cer_from_string(gt, layout_tokens=None): + """ + Compute length after formatting of ground truth string + """ + return len(format_string_for_cer(gt, layout_tokens)) + + +def edit_wer_from_string(gt, pred, layout_tokens=None): + """ + Format and compute edit distance between two strings at word level + """ + split_gt = format_string_for_wer(gt, layout_tokens) + split_pred = format_string_for_wer(pred, layout_tokens) + return edit_wer_from_formatted_split_text(split_gt, split_pred) + + +def format_string_for_wer(str, layout_tokens): + """ + Format string for WER computation: remove layout tokens, treat punctuation as word, replace line break by space + """ + str = re.sub('([\[\]{}/\\()\"\'&+*=<>?.;:,!\-—_€#%°])', r' \1 ', str) # punctuation processed as word + if layout_tokens is not None: + str = keep_all_but_tokens(str, layout_tokens) # remove layout tokens from metric + str = re.sub('([ \n])+', " ", str).strip() # keep only one space character + return str.split(" ") + + +def format_string_for_cer(str, layout_tokens): + """ + Format string for CER computation: remove layout tokens and extra spaces + """ + if layout_tokens is not None: + str = keep_all_but_tokens(str, layout_tokens) # remove layout tokens from metric + str = 
re.sub('([\n])+', "\n", str)  # remove consecutive line breaks
+    str = re.sub('([ ])+', " ", str).strip()  # remove consecutive spaces
+    return str
+
+
+def edit_wer_from_formatted_split_text(gt, pred):
+    """
+    Compute edit distance at word level from formatted string as list
+    """
+    return editdistance.eval(gt, pred)
+
+
+def extract_by_tokens(input_str, begin_token, end_token, associated_score=None, order_by_score=False):
+    """
+    Extract list of text regions by begin and end tokens
+    Order the list by confidence score
+    """
+    if order_by_score:
+        assert associated_score is not None
+    res = list()
+    for match in re.finditer("{}[^{}]*{}".format(begin_token, end_token, end_token), input_str):
+        begin, end = match.regs[0]
+        if order_by_score:
+            res.append({
+                "confidence": np.mean([associated_score[begin], associated_score[end-1]]),
+                "content": input_str[begin+1:end-1]
+            })
+        else:
+            res.append(input_str[begin+1:end-1])
+    if order_by_score:
+        res = sorted(res, key=lambda x: x["confidence"], reverse=True)
+        res = [r["content"] for r in res]
+    return res
+
+
+def compute_layout_precision_per_threshold(gt, pred, score, begin_token, end_token, layout_tokens, return_weight=True):
+    """
+    Compute average precision of a given class for CER threshold from 5% to 50% with a step of 5%
+    """
+    pred_list = extract_by_tokens(pred, begin_token, end_token, associated_score=score, order_by_score=True)
+    gt_list = extract_by_tokens(gt, begin_token, end_token)
+    pred_list = [keep_all_but_tokens(p, layout_tokens) for p in pred_list]
+    gt_list = [keep_all_but_tokens(gt, layout_tokens) for gt in gt_list]
+    precision_per_threshold = [compute_layout_AP_for_given_threshold(gt_list, pred_list, threshold/100) for threshold in range(5, 51, 5)]
+    if return_weight:
+        return precision_per_threshold, len(gt_list)
+    return precision_per_threshold
+
+
+def compute_layout_AP_for_given_threshold(gt_list, pred_list, threshold):
+    """
+    Compute average precision of a given class for a given CER threshold
+    """
+    remaining_gt_list = gt_list.copy()
+    num_true = len(gt_list)
+    correct = np.zeros((len(pred_list)), dtype=bool)
+    for i, pred in enumerate(pred_list):
+        if len(remaining_gt_list) == 0:
+            break
+        cer_with_gt = [edit_cer_from_string(gt, pred)/nb_chars_cer_from_string(gt) for gt in remaining_gt_list]
+        cer, ind = np.min(cer_with_gt), np.argmin(cer_with_gt)
+        if cer <= threshold:
+            correct[i] = True
+            del remaining_gt_list[ind]
+    precision = np.cumsum(correct, dtype=int) / np.arange(1, len(pred_list)+1)
+    recall = np.cumsum(correct, dtype=int) / num_true
+    max_precision_from_recall = np.maximum.accumulate(precision[::-1])[::-1]
+    recall_diff = (recall - np.concatenate([np.array([0, ]), recall[:-1]]))
+    P = np.sum(recall_diff * max_precision_from_recall)
+    return P
+
+
+def compute_layout_mAP_per_class(gt, pred, score, tokens):
+    """
+    Compute the mAP_cer for each class for a given sample
+    """
+    layout_tokens = "".join(list(tokens.keys()))
+    AP_per_class = dict()
+    for token in tokens.keys():
+        if token in gt:
+            AP_per_class[token] = compute_layout_precision_per_threshold(gt, pred, score, token, tokens[token], layout_tokens=layout_tokens)
+    return AP_per_class
+
+
+def compute_global_mAP(list_AP_per_class):
+    """
+    Compute the global mAP_cer for several samples
+    """
+    weights_per_doc = list()
+    mAP_per_doc = list()
+    for doc_AP_per_class in list_AP_per_class:
+        APs = np.array([np.mean(doc_AP_per_class[key][0]) for key in doc_AP_per_class.keys()])
+        weights = np.array([doc_AP_per_class[key][1] for key in doc_AP_per_class.keys()])
+        if np.sum(weights) == 0:
+            mAP_per_doc.append(0)
+        else:
+            mAP_per_doc.append(np.average(APs, weights=weights))
+        weights_per_doc.append(np.sum(weights))
+    if np.sum(weights_per_doc) == 0:
+        return 0
+    return np.average(mAP_per_doc, weights=weights_per_doc)
+
+
+def compute_global_mAP_per_class(list_AP_per_class):
+    """
+    Compute the mAP_cer per class for several samples
+    """
+    mAP_per_class = dict()
+    for doc_AP_per_class in list_AP_per_class:
+        for key in doc_AP_per_class.keys():
+            if key not in mAP_per_class:
+                mAP_per_class[key] = {
+                    "AP": list(),
+                    "weights": list()
+                }
+            mAP_per_class[key]["AP"].append(np.mean(doc_AP_per_class[key][0]))
+            mAP_per_class[key]["weights"].append(doc_AP_per_class[key][1])
+    for key in mAP_per_class.keys():
+        mAP_per_class[key] = np.average(mAP_per_class[key]["AP"], weights=mAP_per_class[key]["weights"])
+    return mAP_per_class
+
+
+def compute_global_precision_per_class_per_threshold(list_AP_per_class):
+    """
+    Compute the mAP_cer per class and per threshold for several samples
+    """
+    mAP_per_class = dict()
+    for doc_AP_per_class in list_AP_per_class:
+        for key in doc_AP_per_class.keys():
+            if key not in mAP_per_class:
+                mAP_per_class[key] = dict()
+                for threshold in range(5, 51, 5):
+                    mAP_per_class[key][threshold] = {
+                        "precision": list(),
+                        "weights": list()
+                    }
+            for i, threshold in enumerate(range(5, 51, 5)):
+                mAP_per_class[key][threshold]["precision"].append(np.mean(doc_AP_per_class[key][0][i]))
+                mAP_per_class[key][threshold]["weights"].append(doc_AP_per_class[key][1])
+    for key_class in mAP_per_class.keys():
+        for threshold in mAP_per_class[key_class]:
+            mAP_per_class[key_class][threshold] = np.average(mAP_per_class[key_class][threshold]["precision"], weights=mAP_per_class[key_class][threshold]["weights"])
+    return mAP_per_class
+
+
+def str_to_graph_read(str):
+    """
+    Compute graph from string of layout tokens for the READ 2016 dataset at single-page and double-page levels
+    """
+    begin_layout_tokens = "".join(list(READ_MATCHING_TOKENS.keys()))
+    layout_token_sequence = keep_only_tokens(str, begin_layout_tokens)
+    g = nx.DiGraph()
+    g.add_node("D", type="document", level=4, page=0)
+    num = {
+        "ⓟ": 0,
+        "ⓝ": 0,
+        "ⓑ": 0,
+        "ⓐ": 0,
+        "ⓢ": 0
+    }
+    previous_top_level_node = None
+    previous_middle_level_node = None
+    previous_low_level_node = None
+    for ind, c in enumerate(layout_token_sequence):
+        num[c] += 1
+        if c == "ⓟ":
+            node_name = "P_{}".format(num[c])
+            g.add_node(node_name, type="page", level=3, page=num["ⓟ"])
+            g.add_edge("D", node_name)
+            if previous_top_level_node:
+                g.add_edge(previous_top_level_node, node_name)
+            previous_top_level_node = node_name
+            previous_middle_level_node = None
+            previous_low_level_node = None
+        if c in "ⓝⓢ":
+            node_name = "{}_{}".format("N" if c == "ⓝ" else "S", num[c])
+            g.add_node(node_name, type="number" if c == "ⓝ" else "section", level=2, page=num["ⓟ"])
+            g.add_edge(previous_top_level_node, node_name)
+            if previous_middle_level_node:
+                g.add_edge(previous_middle_level_node, node_name)
+            previous_middle_level_node = node_name
+            previous_low_level_node = None
+        if c in "ⓐⓑ":
+            node_name = "{}_{}".format("A" if c == "ⓐ" else "B", num[c])
+            g.add_node(node_name, type="annotation" if c == "ⓐ" else "body", level=1, page=num["ⓟ"])
+            g.add_edge(previous_middle_level_node, node_name)
+            if previous_low_level_node:
+                g.add_edge(previous_low_level_node, node_name)
+            previous_low_level_node = node_name
+    return g
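+
+# Illustrative example (token string made up for this sketch): for the sequence
+# "ⓟⓢⓑⓑ", str_to_graph_read builds the nodes D, P_1, S_1, B_1, B_2 with edges
+# D->P_1, P_1->S_1, S_1->B_1, S_1->B_2, plus the reading-order edge B_1->B_2:
+#   g = str_to_graph_read("ⓟⓢⓑⓑ")
+#   sorted(g.nodes()) == ["B_1", "B_2", "D", "P_1", "S_1"]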
+
+
+def str_to_graph_rimes(str):
+    """
+    Compute graph from string of layout tokens for the RIMES dataset at page level
+    """
+    begin_layout_tokens = "".join(list(RIMES_MATCHING_TOKENS.keys()))
+    layout_token_sequence = keep_only_tokens(str, begin_layout_tokens)
+    g = nx.DiGraph()
+    g.add_node("D", type="document", level=2, page=0)
+    token_name_dict = {
+        "ⓑ": "B",
+        "ⓞ": "O",
+        "ⓡ": "R",
+        "ⓢ": "S",
+        "ⓦ": "W",
+        "ⓨ": "Y",
+        "ⓟ": "P"
+    }
+    num = dict()
+    previous_node = None
+    for token in begin_layout_tokens:
+        num[token] = 0
+    for ind, c in enumerate(layout_token_sequence):
+        num[c] += 1
+        node_name = "{}_{}".format(token_name_dict[c], num[c])
+        g.add_node(node_name, type=token_name_dict[c], level=1, page=0)
+        g.add_edge("D", node_name)
+        if previous_node:
+            g.add_edge(previous_node, node_name)
+        previous_node = node_name
+    return g
+
+
+def str_to_graph_simara(str):
+    """
+    Compute graph from string of layout tokens for the SIMARA dataset at page level
+    """
+    begin_layout_tokens = "".join(list(SIMARA_MATCHING_TOKENS.keys()))
+    layout_token_sequence = keep_only_tokens(str, begin_layout_tokens)
+    g = nx.DiGraph()
+    g.add_node("D", type="document", level=2, page=0)
+    token_name_dict = {
+        "ⓘ": "I",
+        "ⓓ": "D",
+        "ⓢ": "S",
+        "ⓒ": "C",
+        "ⓟ": "P",
+        "ⓐ": "A"
+    }
+    num = dict()
+    previous_node = None
+    for token in begin_layout_tokens:
+        num[token] = 0
+    for ind, c in enumerate(layout_token_sequence):
+        num[c] += 1
+        node_name = "{}_{}".format(token_name_dict[c], num[c])
+        g.add_node(node_name, type=token_name_dict[c], level=1, page=0)
+        g.add_edge("D", node_name)
+        if previous_node:
+            g.add_edge(previous_node, node_name)
+        previous_node = node_name
+    return g
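+
+# Minimal usage sketch for the LOER inputs below (made-up strings): with
+# gt = "ⓘaⒾⓓbⒹ" and pred = "ⓘaⒾ", the prediction graph is missing the D_1
+# node and its two incident edges, so
+#   edit_and_num_items_for_ged_from_str_simara(gt, pred) == (3, 6)
+# where 6 = number_of_nodes + number_of_edges of the ground-truth graph,
+# used as the LOER normalizer.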
+
+
+def edit_and_num_items_for_ged_from_str_rimes(str_gt, str_pred):
+    """
+    Compute graph edit distance and num nodes/edges for normalized graph edit distance
+    For the RIMES dataset
+    """
+    g_gt = str_to_graph_rimes(str_gt)
+    g_pred = str_to_graph_rimes(str_pred)
+    return graph_edit_distance(g_gt, g_pred), g_gt.number_of_nodes() + g_gt.number_of_edges()
+
+
+def edit_and_num_items_for_ged_from_str_simara(str_gt, str_pred):
+    """
+    Compute graph edit distance and num nodes/edges for normalized graph edit distance
+    For the SIMARA dataset
+    """
+    g_gt = str_to_graph_simara(str_gt)
+    g_pred = str_to_graph_simara(str_pred)
+    return graph_edit_distance(g_gt, g_pred), g_gt.number_of_nodes() + g_gt.number_of_edges()
diff --git a/basic/scheduler.py b/basic/scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..6c875c1da2f57ab0306635cbc77309c2b83af116
--- /dev/null
+++ b/basic/scheduler.py
@@ -0,0 +1,48 @@
+
+from torch.nn import Dropout, Dropout2d
+import numpy as np
+
+
+class DropoutScheduler:
+
+    def __init__(self, models, function, T=1e5):
+        """
+        T: number of gradient updates to converge
+        """
+
+        self.teta_list = list()
+        self.init_teta_list(models)
+        self.function = function
+        self.T = T
+        self.step_num = 0
+
+    def step(self, num=1):
+        self.step_num += num
+
+    def init_teta_list(self, models):
+        for model_name in models.keys():
+            self.init_teta_list_module(models[model_name])
+
+    def init_teta_list_module(self, module):
+        for child in module.children():
+            if isinstance(child, Dropout) or isinstance(child, Dropout2d):
+                self.teta_list.append([child, child.p])
+            else:
+                self.init_teta_list_module(child)
+
+    def update_dropout_rate(self):
+        for (module, p) in self.teta_list:
+            module.p = self.function(p, self.step_num, self.T)
+
+
+def exponential_dropout_scheduler(dropout_rate, step, max_step):
+    return dropout_rate * (1 - np.exp(-10 * step / max_step))
+
+
+def exponential_scheduler(init_value, end_value, step, max_step):
+    step = min(step, max_step-1)
+    return init_value - (init_value - end_value) * (1 - np.exp(-10*step/max_step))
+
+
+def linear_scheduler(init_value, end_value, step, max_step):
+    return init_value + step * (end_value - init_value) / max_step
\ No newline at end of file
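
A minimal training-loop sketch for the scheduler (hypothetical; `models` is assumed to be a dict of torch.nn.Module instances, which is what init_teta_list expects, and `train_loader` an iterable of batches):

    # Hypothetical usage, not part of the patch.
    scheduler = DropoutScheduler(models, exponential_dropout_scheduler, T=1e5)
    for batch in train_loader:
        ...                              # forward, backward, optimizer.step()
        scheduler.step()                 # count one gradient update
        scheduler.update_dropout_rate()  # ramp each Dropout(2d).p from ~0 toward its initial value
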
diff --git a/basic/utils.py b/basic/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac50e57ab1136a2f2e1d209c39271406316d0479
--- /dev/null
+++ b/basic/utils.py
@@ -0,0 +1,179 @@
+
+
+import numpy as np
+import torch
+from torch.distributions.uniform import Uniform
+import cv2
+
+
+def randint(low, high):
+    """
+    Call torch.randint to preserve randomness among dataloader workers
+    """
+    return int(torch.randint(low, high, (1, )))
+
+
+def rand():
+    """
+    Call torch.rand to preserve randomness among dataloader workers
+    """
+    return float(torch.rand((1, )))
+
+
+def rand_uniform(low, high):
+    """
+    Call torch Uniform to preserve randomness among dataloader workers
+    """
+    return float(Uniform(low, high).sample())
+
+
+def pad_sequences_1D(data, padding_value):
+    """
+    Pad data with padding_value to get same length
+    """
+    x_lengths = [len(x) for x in data]
+    longest_x = max(x_lengths)
+    padded_data = np.ones((len(data), longest_x)).astype(np.int32) * padding_value
+    for i, x_len in enumerate(x_lengths):
+        padded_data[i, :x_len] = data[i][:x_len]
+    return padded_data
+
+
+def resize_max(img, max_width=None, max_height=None):
+    if max_width is not None and img.shape[1] > max_width:
+        ratio = max_width / img.shape[1]
+        new_h = int(np.floor(ratio * img.shape[0]))
+        new_w = int(np.floor(ratio * img.shape[1]))
+        img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
+    if max_height is not None and img.shape[0] > max_height:
+        ratio = max_height / img.shape[0]
+        new_h = int(np.floor(ratio * img.shape[0]))
+        new_w = int(np.floor(ratio * img.shape[1]))
+        img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
+    return img
+
+
+def pad_images(data, padding_value, padding_mode="br"):
+    """
+    data: list of numpy array
+    mode: "br"/"tl"/"random" (bottom-right, top-left, random)
+    """
+    x_lengths = [x.shape[0] for x in data]
+    y_lengths = [x.shape[1] for x in data]
+    longest_x = max(x_lengths)
+    longest_y = max(y_lengths)
+    padded_data = np.ones((len(data), longest_x, longest_y, data[0].shape[2])) * padding_value
+    for i, xy_len in enumerate(zip(x_lengths, y_lengths)):
+        x_len, y_len = xy_len
+        if padding_mode == "br":
+            padded_data[i, :x_len, :y_len, ...] = data[i]
+        elif padding_mode == "tl":
+            padded_data[i, -x_len:, -y_len:, ...] = data[i]
+        elif padding_mode == "random":
+            xmax = longest_x - x_len
+            ymax = longest_y - y_len
+            xi = randint(0, xmax) if xmax >= 1 else 0
+            yi = randint(0, ymax) if ymax >= 1 else 0
+            padded_data[i, xi:xi+x_len, yi:yi+y_len, ...] = data[i]
+        else:
+            raise NotImplementedError("Undefined padding mode: {}".format(padding_mode))
+    return padded_data
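
A brief sketch of how this helper batches variable-sized page images (hypothetical shapes; not part of the patch):

    # Hypothetical batching sketch using pad_images above.
    import numpy as np

    imgs = [np.zeros((100, 80, 3)), np.zeros((120, 60, 3))]  # two H x W x C pages
    batch = pad_images(imgs, padding_value=255, padding_mode="br")
    print(batch.shape)  # (2, 120, 80, 3): every image padded to the largest H and W
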
+
+
+def pad_image(image, padding_value, new_height=None, new_width=None, pad_width=None, pad_height=None, padding_mode="br", return_position=False):
+    """
+    image: numpy array
+    mode: "br"/"tl"/"random" (bottom-right, top-left, random)
+    """
+    if pad_width is not None and new_width is not None:
+        raise NotImplementedError("pad_width and new_width are not compatible")
+    if pad_height is not None and new_height is not None:
+        raise NotImplementedError("pad_height and new_height are not compatible")
+
+    h, w, c = image.shape
+    pad_width = pad_width if pad_width is not None else max(0, new_width - w) if new_width is not None else 0
+    pad_height = pad_height if pad_height is not None else max(0, new_height - h) if new_height is not None else 0
+
+    if not (pad_width == 0 and pad_height == 0):
+        padded_image = np.ones((h+pad_height, w+pad_width, c)) * padding_value
+        if padding_mode == "br":
+            hi, wi = 0, 0
+        elif padding_mode == "tl":
+            hi, wi = pad_height, pad_width
+        elif padding_mode == "random":
+            hi = randint(0, pad_height) if pad_height >= 1 else 0
+            wi = randint(0, pad_width) if pad_width >= 1 else 0
+        else:
+            raise NotImplementedError("Undefined padding mode: {}".format(padding_mode))
+        padded_image[hi:hi + h, wi:wi + w, ...] = image
+        output = padded_image
+    else:
+        hi, wi = 0, 0
+        output = image
+
+    if return_position:
+        return output, [[hi, hi+h], [wi, wi+w]]
+    return output
+
+
+def pad_image_width_right(img, new_width, padding_value):
+    """
+    Pad img to right side with padding value to reach new_width as width
+    """
+    h, w, c = img.shape
+    pad_width = max((new_width - w), 0)
+    pad_right = np.ones((h, pad_width, c), dtype=img.dtype) * padding_value
+    img = np.concatenate([img, pad_right], axis=1)
+    return img
+
+
+def pad_image_width_left(img, new_width, padding_value):
+    """
+    Pad img to left side with padding value to reach new_width as width
+    """
+    h, w, c = img.shape
+    pad_width = max((new_width - w), 0)
+    pad_left = np.ones((h, pad_width, c), dtype=img.dtype) * padding_value
+    img = np.concatenate([pad_left, img], axis=1)
+    return img
+
+
+def pad_image_width_random(img, new_width, padding_value, max_pad_left_ratio=1):
+    """
+    Randomly pad img to left and right sides with padding value to reach new_width as width
+    """
+    h, w, c = img.shape
+    pad_width = max((new_width - w), 0)
+    max_pad_left = int(max_pad_left_ratio*pad_width)
+    pad_left = randint(0, min(pad_width, max_pad_left)) if pad_width != 0 and max_pad_left > 0 else 0
+    pad_right = pad_width - pad_left
+    pad_left = np.ones((h, pad_left, c), dtype=img.dtype) * padding_value
+    pad_right = np.ones((h, pad_right, c), dtype=img.dtype) * padding_value
+    img = np.concatenate([pad_left, img, pad_right], axis=1)
+    return img
+
+
+def pad_image_height_random(img, new_height, padding_value, max_pad_top_ratio=1):
+    """
+    Randomly pad img top and bottom sides with padding value to reach new_height as height
+    """
+    h, w, c = img.shape
+    pad_height = max((new_height - h), 0)
+    max_pad_top = int(max_pad_top_ratio*pad_height)
+    pad_top = randint(0, min(pad_height, max_pad_top)) if pad_height != 0 and max_pad_top > 0 else 0
+    pad_bottom = pad_height - pad_top
+    pad_top = np.ones((pad_top, w, c), dtype=img.dtype) * padding_value
+    pad_bottom = np.ones((pad_bottom, w, c), dtype=img.dtype) * padding_value
+    img = np.concatenate([pad_top, img, pad_bottom], axis=0)
+    return img
+
+
+def pad_image_height_bottom(img, new_height, padding_value):
+    """
+    Pad img to bottom side with padding value to reach new_height as height
+    """
+    h, w, c = img.shape
+    pad_height = max((new_height - h), 0)
+    pad_bottom = np.ones((pad_height, w, c), dtype=img.dtype) * padding_value
+    img = np.concatenate([img, pad_bottom], axis=0)
+    return img
diff --git a/dan/cli.py b/dan/cli.py
index e633ceb44e09fdfce699082d4b72f0c3c42cdfa6..7c25144ea4255fb812f23f01c50a42a2ef4e8066 100644
--- a/dan/cli.py
+++ b/dan/cli.py
@@ -16,7 +16,6 @@ def get_parser():
     add_generate_parser(subcommands)
     return parser
 
-
 def main():
     parser = get_parser()
     args = vars(parser.parse_args())
diff --git a/dan/datasets/extract/utils.py b/dan/datasets/extract/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..72ab461b8214a14b144cc6c2edc725cad38cf6bc
--- /dev/null
+++ b/dan/datasets/extract/utils.py
@@ -0,0 +1,50 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+"""
+    The utils module
+    ======================
+"""
+
+import argparse
+
+
+def get_cli_args():
+    """
+    Get the command-line arguments.
+    :return: The command-line arguments.
+    """
+    parser = argparse.ArgumentParser(
+        description="Arkindex DAN Training Label Generation"
+    )
+
+    # Required arguments.
+    parser.add_argument(
+        "--corpus",
+        type=str,
+        help="Name of the corpus from which the data will be retrieved.",
+        required=True,
+    )
+    parser.add_argument(
+        "--parents-types",
+        nargs="+",
+        type=str,
+        help="Type of parents of the elements.",
+        required=True,
+    )
+    parser.add_argument(
+        "--output-dir",
+        type=str,
+        help="Path to the output directory.",
+        required=True,
+    )
+
+    # Optional arguments.
+    parser.add_argument(
+        "--parents-names",
+        nargs="+",
+        type=str,
+        help="Names of parents of the elements.",
+        default=None,
+    )
+    return parser.parse_args()
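
A quick usage sketch for this parser (hypothetical, since the patch does not show the entry point that consumes get_cli_args; argparse maps each --flag to an underscored attribute):

    # Hypothetical sketch; simulates a command line for get_cli_args().
    import sys

    sys.argv = [
        "extract",                      # program name placeholder
        "--corpus", "My corpus",
        "--parents-types", "folder",
        "--output-dir", "./labels",
    ]
    args = get_cli_args()
    print(args.corpus)         # "My corpus"
    print(args.parents_types)  # ["folder"] (nargs="+")
    print(args.parents_names)  # None (optional, default)
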
diff --git a/dan/manager/dataset.py b/dan/manager/dataset.py
index 64633ebe53445b4b1768e773ba76c9195e079f77..b35b5248fbdf08442cb8a7f27bc974e31c4b0559 100644
--- a/dan/manager/dataset.py
+++ b/dan/manager/dataset.py
@@ -1,4 +1,9 @@
 # -*- coding: utf-8 -*-
+import torch
+from torch.utils.data import Dataset, DataLoader
+from torch.utils.data.distributed import DistributedSampler
+from basic.transforms import apply_data_augmentation
+from Datasets.dataset_formatters.utils_dataset import natural_sort
 import os
 import pickle
 import random
diff --git a/dan/ocr/document/train.py b/dan/ocr/document/train.py
index 09df425a5743c57eb582426bd16902a3b5aa60ad..e6bc9ed14a8b8ffbd8e27a5888bf459c25f47b4f 100644
--- a/dan/ocr/document/train.py
+++ b/dan/ocr/document/train.py
@@ -18,7 +18,6 @@ def add_document_parser(subcommands) -> None:
     parser = subcommands.add_parser(
         "document",
         description=__doc__,
-        help=__doc__,
     )
 
     parser.set_defaults(func=run)
diff --git a/dan/ocr/line/utils.py b/dan/ocr/line/utils.py
index 01ed68be78629637eb19562b5dbca26e848971bf..71061c5f3d4963732cf387ff56fe5d7e2b7f65e0 100644
--- a/dan/ocr/line/utils.py
+++ b/dan/ocr/line/utils.py
@@ -1,4 +1,10 @@
 # -*- coding: utf-8 -*-
+from basic.metric_manager import MetricManager
+from OCR.ocr_manager import OCRManager
+from dan.ocr_utils import LM_ind_to_str
+import torch
+from torch.cuda.amp import autocast
+from torch.nn import CTCLoss
 import re
 import time
 
diff --git a/dan/transforms.py b/dan/transforms.py
index b33489c63cc88e42f43077286f550af587b987aa..0ef0a70bd0f3e94dc6acf3169917e8f93b6a71b8 100644
--- a/dan/transforms.py
+++ b/dan/transforms.py
@@ -1,4 +1,13 @@
 # -*- coding: utf-8 -*-
 """
 Each transform class defined here takes as input a PIL Image and returns the modified PIL Image
 """
+import numpy as np
+from numpy import random
+from PIL import Image, ImageOps
+from cv2 import erode, dilate, normalize
+import cv2
+import math
+from basic.utils import randint, rand_uniform, rand
+from torchvision.transforms import RandomPerspective, RandomCrop, ColorJitter, GaussianBlur, RandomRotation
+from torchvision.transforms.functional import InterpolationMode