diff --git a/Datasets/dataset_formatters/generic_dataset_formatter.py b/Datasets/dataset_formatters/generic_dataset_formatter.py
deleted file mode 100644
index ae55af026c841af4a9ab8ee08e3437bf9d16c034..0000000000000000000000000000000000000000
--- a/Datasets/dataset_formatters/generic_dataset_formatter.py
+++ /dev/null
@@ -1,89 +0,0 @@
-import os
-import shutil
-import tarfile
-import pickle
-import re
-from PIL import Image
-import numpy as np
-
-
-class DatasetFormatter:
-    """
-    Global pipeline/functions for dataset formatting
-    """
-
-    def __init__(self, dataset_name, level, extra_name="", set_names=["train", "valid", "test"]):
-        self.dataset_name = dataset_name
-        self.level = level
-        self.set_names = set_names
-        self.target_fold_path = os.path.join(
-            "Datasets", "formatted", "{}_{}{}".format(dataset_name, level, extra_name))
-        self.map_datasets_files = dict()
-        self.extract_with_dirname = False
-
-    def format(self):
-        self.init_format()
-        self.map_datasets_files[self.dataset_name][self.level]["format_function"]()
-        self.end_format()
-
-    def init_format(self):
-        """
-        Load and extract needed files
-        """
-        os.makedirs(self.target_fold_path, exist_ok=True)
-
-        for set_name in self.set_names:
-            os.makedirs(os.path.join(self.target_fold_path, set_name), exist_ok=True)
-
-
-class OCRDatasetFormatter(DatasetFormatter):
-    """
-    Specific pipeline/functions for OCR/HTR dataset formatting
-    """
-
-    def __init__(self, source_dataset, level, extra_name="", set_names=["train", "valid", "test"]):
-        super(OCRDatasetFormatter, self).__init__(source_dataset, level, extra_name, set_names)
-        self.charset = set()
-        self.gt = dict()
-        for set_name in set_names:
-            self.gt[set_name] = dict()
-
-    def format_text_label(self, label):
-        """
-        Remove extra space or line break characters
-        """
-        temp = re.sub("(\n)+", '\n', label)
-        return re.sub("( )+", ' ', temp).strip(" \n")
-
-    def load_resize_save(self, source_path, target_path):
-        """
-        Load image, apply resolution modification and save it
-        """
-        shutil.copyfile(source_path, target_path)
-
-    def resize(self, img, source_dpi, target_dpi):
-        """
-        Apply resolution modification to image
-        """
-        if source_dpi == target_dpi:
-            return img
-        if isinstance(img, np.ndarray):
-            h, w = img.shape[:2]
-            img = Image.fromarray(img)
-        else:
-            w, h = img.size
-        ratio = target_dpi / source_dpi
-        img = img.resize((int(w*ratio), int(h*ratio)), Image.BILINEAR)
-        return np.array(img)
-
-    def end_format(self):
-        """
-        Save label and charset files
-        """
-        with open(os.path.join(self.target_fold_path, "labels.pkl"), "wb") as f:
-            pickle.dump({
-                "ground_truth": self.gt,
-                "charset": sorted(list(self.charset)),
-            }, f)
-        with open(os.path.join(self.target_fold_path, "charset.pkl"), "wb") as f:
-            pickle.dump(sorted(list(self.charset)), f)
diff --git a/Datasets/dataset_formatters/simara_formatter.py b/Datasets/dataset_formatters/simara_formatter.py
deleted file mode 100644
index c60eb30f94cf569d3b2d9d9197a16756ee994c01..0000000000000000000000000000000000000000
--- a/Datasets/dataset_formatters/simara_formatter.py
+++ /dev/null
@@ -1,104 +0,0 @@
-from Datasets.dataset_formatters.generic_dataset_formatter import OCRDatasetFormatter
-import os
-import numpy as np
-from Datasets.dataset_formatters.utils_dataset import natural_sort
-from PIL import Image
-import xml.etree.ElementTree as ET
-import re
-from tqdm import tqdm
-
-# Layout string to token
-SEM_MATCHING_TOKENS_STR = {
-    "INTITULE": "ⓘ",
-    "DATE": "ⓓ",
-    "COTE_SERIE": "ⓢ",
-    "ANALYSE_COMPL": "ⓒ",
-    "PRECISIONS_SUR_COTE": "ⓟ",
-    "COTE_ARTICLE": "ⓐ"
-}
-
-# Layout begin-token to end-token
-SEM_MATCHING_TOKENS = {
-    "ⓘ": "Ⓘ",
-    "ⓓ": "Ⓓ",
-    "ⓢ": "Ⓢ",
-    "ⓒ": "Ⓒ",
-    "ⓟ": "Ⓟ",
-    "ⓐ": "Ⓐ"
-}
-
-
-class SimaraDatasetFormatter(OCRDatasetFormatter):
-    def __init__(self, level, set_names=["train", "valid", "test"], dpi=150, sem_token=True):
-        super(SimaraDatasetFormatter, self).__init__("simara", level, "_sem" if sem_token else "", set_names)
-
-        self.dpi = dpi
-        self.sem_token = sem_token
-        self.map_datasets_files.update({
-            "simara": {
-                # (1,050 for train, 100 for validation and 100 for test)
-                "page": {
-                    "format_function": self.format_simara_page,
-                },
-            }
-        })
-        self.matching_tokens_str = SEM_MATCHING_TOKENS_STR
-        self.matching_tokens = SEM_MATCHING_TOKENS
-
-    def preformat_simara_page(self):
-        """
-        Extract all information from dataset and correct some annotations
-        """
-        dataset = {
-            "train": list(),
-            "valid": list(),
-            "test": list()
-        }
-        img_folder_path = os.path.join("Datasets", "raw", "simara", "images")
-        labels_folder_path = os.path.join("Datasets", "raw", "simara", "labels")
-        sem_labels_folder_path = os.path.join("Datasets", "raw", "simara", "labels_sem")
-        train_files = [
-            os.path.join(labels_folder_path, 'train', name)
-            for name in os.listdir(os.path.join(sem_labels_folder_path, 'train'))]
-        valid_files = [
-            os.path.join(labels_folder_path, 'valid', name)
-            for name in os.listdir(os.path.join(sem_labels_folder_path, 'valid'))]
-        test_files = [
-            os.path.join(labels_folder_path, 'test', name)
-            for name in os.listdir(os.path.join(sem_labels_folder_path, 'test'))]
-        for set_name, files in zip(self.set_names, [train_files, valid_files, test_files]):
-            for i, label_file in enumerate(tqdm(files, desc='Pre-formatting '+set_name)):
-                with open(label_file, 'r') as f:
-                    text = f.read()
-                with open(label_file.replace('labels', 'labels_sem'), 'r') as f:
-                    sem_text = f.read()
-                dataset[set_name].append({
-                    "img_path": os.path.join(
-                        img_folder_path, set_name, label_file.split('/')[-1].replace('txt', 'jpg')),
-                    "label": text,
-                    "sem_label": sem_text,
-                })
-        return dataset
-
-    def format_simara_page(self):
-        """
-        Format simara page dataset
-        """
-        dataset = self.preformat_simara_page()
-        for set_name in self.set_names:
-            fold = os.path.join(self.target_fold_path, set_name)
-            for sample in tqdm(dataset[set_name], desc='Formatting '+set_name):
-                new_name = sample['img_path'].split('/')[-1]
-                new_img_path = os.path.join(fold, new_name)
-                self.load_resize_save(sample["img_path"], new_img_path)
-                page = {
-                    "text": sample["label"] if not self.sem_token else sample["sem_label"],
-                }
-                self.charset = self.charset.union(set(page["text"]))
-                self.gt[set_name][new_name] = page
-
-
-if __name__ == "__main__":
-
-    SimaraDatasetFormatter("page", sem_token=True).format()
-    #SimaraDatasetFormatter("page", sem_token=False).format()
diff --git a/OCR/line_OCR/ctc/models_line_ctc.py b/OCR/line_OCR/ctc/models_line_ctc.py
deleted file mode 100644
index c13c072ebd10e810d41b8aab1e318c2170637ea0..0000000000000000000000000000000000000000
--- a/OCR/line_OCR/ctc/models_line_ctc.py
+++ /dev/null
@@ -1,19 +0,0 @@
-
-from torch.nn.functional import log_softmax
-from torch.nn import AdaptiveMaxPool2d, Conv1d
-from torch.nn import Module
-
-
-class Decoder(Module):
-    def __init__(self, params):
-        super(Decoder, self).__init__()
-
-        self.vocab_size = params["vocab_size"]
-
-        self.ada_pool = AdaptiveMaxPool2d((1, None))
-        self.end_conv =
Conv1d(in_channels=params["enc_size"], out_channels=self.vocab_size+1, kernel_size=1) - - def forward(self, x): - x = self.ada_pool(x).squeeze(2) - x = self.end_conv(x) - return log_softmax(x, dim=1) diff --git a/OCR/ocr_manager.py b/OCR/ocr_manager.py deleted file mode 100644 index e34b36c0a0793868c5a8f590b20510db72bc0681..0000000000000000000000000000000000000000 --- a/OCR/ocr_manager.py +++ /dev/null @@ -1,67 +0,0 @@ -from basic.generic_training_manager import GenericTrainingManager -import os -from PIL import Image -import pickle - - -class OCRManager(GenericTrainingManager): - def __init__(self, params): - super(OCRManager, self).__init__(params) - self.params["model_params"]["vocab_size"] = len(self.dataset.charset) - - def generate_syn_line_dataset(self, name): - """ - Generate synthetic line dataset from currently loaded dataset - """ - dataset_name = list(self.params['dataset_params']["datasets"].keys())[0] - path = os.path.join(os.path.dirname(self.params['dataset_params']["datasets"][dataset_name]), name) - os.makedirs(path, exist_ok=True) - charset = set() - dataset = None - gt = { - "train": dict(), - "valid": dict(), - "test": dict() - } - for set_name in ["train", "valid", "test"]: - set_path = os.path.join(path, set_name) - os.makedirs(set_path, exist_ok=True) - if set_name == "train": - dataset = self.dataset.train_dataset - elif set_name == "valid": - dataset = self.dataset.valid_datasets["{}-valid".format(dataset_name)] - elif set_name == "test": - self.dataset.generate_test_loader("{}-test".format(dataset_name), [(dataset_name, "test"), ]) - dataset = self.dataset.test_datasets["{}-test".format(dataset_name)] - - samples = list() - for sample in dataset.samples: - for line_label in sample["label"].split("\n"): - for chunk in [line_label[i:i+100] for i in range(0, len(line_label), 100)]: - charset = charset.union(set(chunk)) - if len(chunk) > 0: - samples.append({ - "path": sample["path"], - "label": chunk, - "nb_cols": 1, - }) - - for i, sample in enumerate(samples): - ext = sample['path'].split(".")[-1] - img_name = "{}_{}.{}".format(set_name, i, ext) - img_path = os.path.join(set_path, img_name) - - img = dataset.generate_typed_text_line_image(sample["label"]) - Image.fromarray(img).save(img_path) - gt[set_name][img_name] = { - "text": sample["label"], - "nb_cols": sample["nb_cols"] if "nb_cols" in sample else 1 - } - if "line_label" in sample: - gt[set_name][img_name]["lines"] = sample["line_label"] - - with open(os.path.join(path, "labels.pkl"), "wb") as f: - pickle.dump({ - "ground_truth": gt, - "charset": sorted(list(charset)), - }, f) \ No newline at end of file diff --git a/basic/generic_training_manager.py b/basic/generic_training_manager.py deleted file mode 100644 index a8f41b28274678b23083123e33d2343bb6b942b9..0000000000000000000000000000000000000000 --- a/basic/generic_training_manager.py +++ /dev/null @@ -1,706 +0,0 @@ -import torch -import os -import sys -import copy -import json -import torch.distributed as dist -import torch.multiprocessing as mp -import random -import numpy as np -from torch.utils.tensorboard import SummaryWriter -from torch.nn.init import kaiming_uniform_ -from tqdm import tqdm -from time import time -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.cuda.amp import GradScaler -from basic.metric_manager import MetricManager -from basic.scheduler import DropoutScheduler -from datetime import date - - -class GenericTrainingManager: - - def __init__(self, params): - self.type = None - self.is_master = False - 
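# A minimal usage sketch for the CTC Decoder defined above. The sizes are
# illustrative assumptions (B=2, enc_size=256, H=8, W=128), not values taken
# from any training config in this repository.
import torch
decoder = Decoder({"vocab_size": 100, "enc_size": 256})
x = torch.rand((2, 256, 8, 128))  # dummy encoder feature map (B, C, H, W)
out = decoder(x)                  # (2, 101, 128): log-probs over charset + CTC blank, per horizontal position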
self.params = params - self.dropout_scheduler = None - self.models = {} - self.begin_time = None - self.dataset = None - self.dataset_name = list(self.params["dataset_params"]["datasets"].values())[0] - self.paths = None - self.latest_step = 0 - self.latest_epoch = -1 - self.latest_batch = 0 - self.total_batch = 0 - self.grad_acc_step = 0 - self.latest_train_metrics = dict() - self.latest_valid_metrics = dict() - self.curriculum_info = dict() - self.curriculum_info["latest_valid_metrics"] = dict() - self.phase = None - self.max_mem_usage_by_epoch = list() - self.losses = list() - self.lr_values = list() - - self.scaler = None - - self.optimizers = dict() - self.optimizers_named_params_by_group = dict() - self.lr_schedulers = dict() - self.best = None - self.writer = None - self.metric_manager = dict() - - self.init_hardware_config() - self.init_paths() - self.load_dataset() - self.params["model_params"]["use_amp"] = self.params["training_params"]["use_amp"] - - def init_paths(self): - """ - Create output folders for results and checkpoints - """ - output_path = os.path.join("outputs", self.params["training_params"]["output_folder"]) - os.makedirs(output_path, exist_ok=True) - checkpoints_path = os.path.join(output_path, "checkpoints") - os.makedirs(checkpoints_path, exist_ok=True) - results_path = os.path.join(output_path, "results") - os.makedirs(results_path, exist_ok=True) - - self.paths = { - "results": results_path, - "checkpoints": checkpoints_path, - "output_folder": output_path - } - - def load_dataset(self): - """ - Load datasets, data samplers and data loaders - """ - self.params["dataset_params"]["use_ddp"] = self.params["training_params"]["use_ddp"] - self.params["dataset_params"]["batch_size"] = self.params["training_params"]["batch_size"] - if "valid_batch_size" in self.params["training_params"]: - self.params["dataset_params"]["valid_batch_size"] = self.params["training_params"]["valid_batch_size"] - if "test_batch_size" in self.params["training_params"]: - self.params["dataset_params"]["test_batch_size"] = self.params["training_params"]["test_batch_size"] - self.params["dataset_params"]["num_gpu"] = self.params["training_params"]["nb_gpu"] - self.params["dataset_params"]["worker_per_gpu"] = 4 if "worker_per_gpu" not in self.params["dataset_params"] else self.params["dataset_params"]["worker_per_gpu"] - self.dataset = self.params["dataset_params"]["dataset_manager"](self.params["dataset_params"]) - self.dataset.load_datasets() - self.dataset.load_ddp_samplers() - self.dataset.load_dataloaders() - - def init_hardware_config(self): - # Debug mode - if self.params["training_params"]["force_cpu"]: - self.params["training_params"]["use_ddp"] = False - self.params["training_params"]["use_amp"] = False - # Manage Distributed Data Parallel & GPU usage - self.manual_seed = 1111 if "manual_seed" not in self.params["training_params"].keys() else \ - self.params["training_params"]["manual_seed"] - self.ddp_config = { - "master": self.params["training_params"]["use_ddp"] and self.params["training_params"]["ddp_rank"] == 0, - "address": "localhost" if "ddp_addr" not in self.params["training_params"].keys() else self.params["training_params"]["ddp_addr"], - "port": "11111" if "ddp_port" not in self.params["training_params"].keys() else self.params["training_params"]["ddp_port"], - "backend": "nccl" if "ddp_backend" not in self.params["training_params"].keys() else self.params["training_params"]["ddp_backend"], - "rank": self.params["training_params"]["ddp_rank"], - } - self.is_master = 
self.ddp_config["master"] or not self.params["training_params"]["use_ddp"] - if self.params["training_params"]["force_cpu"]: - self.device = "cpu" - else: - if self.params["training_params"]["use_ddp"]: - self.device = torch.device(self.ddp_config["rank"]) - self.params["dataset_params"]["ddp_rank"] = self.ddp_config["rank"] - self.launch_ddp() - else: - self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - self.params["model_params"]["device"] = self.device.type - # Print GPU info - # global - if (self.params["training_params"]["use_ddp"] and self.ddp_config["master"]) or not self.params["training_params"]["use_ddp"]: - print("##################") - print("Available GPUS: {}".format(self.params["training_params"]["nb_gpu"])) - for i in range(self.params["training_params"]["nb_gpu"]): - print("Rank {}: {} {}".format(i, torch.cuda.get_device_name(i), torch.cuda.get_device_properties(i))) - print("##################") - # local - print("Local GPU:") - if self.device != "cpu": - print("Rank {}: {} {}".format(self.params["training_params"]["ddp_rank"], torch.cuda.get_device_name(), torch.cuda.get_device_properties(self.device))) - else: - print("WORKING ON CPU !\n") - print("##################") - - def load_model(self, reset_optimizer=False, strict=True): - """ - Load model weights from scratch or from checkpoints - """ - # Instantiate Model - for model_name in self.params["model_params"]["models"].keys(): - self.models[model_name] = self.params["model_params"]["models"][model_name](self.params["model_params"]) - self.models[model_name].to(self.device) # To GPU or CPU - # make the model compatible with Distributed Data Parallel if used - if self.params["training_params"]["use_ddp"]: - self.models[model_name] = DDP(self.models[model_name], [self.ddp_config["rank"]]) - - # Handle curriculum dropout - if "dropout_scheduler" in self.params["model_params"]: - func = self.params["model_params"]["dropout_scheduler"]["function"] - T = self.params["model_params"]["dropout_scheduler"]["T"] - self.dropout_scheduler = DropoutScheduler(self.models, func, T) - - self.scaler = GradScaler(enabled=self.params["training_params"]["use_amp"]) - - # Check if checkpoint exists - checkpoint = self.get_checkpoint() - if checkpoint is not None: - self.load_existing_model(checkpoint, strict=strict) - else: - self.init_new_model() - - self.load_optimizers(checkpoint, reset_optimizer=reset_optimizer) - - if self.is_master: - print("LOADED EPOCH: {}\n".format(self.latest_epoch), flush=True) - - def get_checkpoint(self): - """ - Seek if checkpoint exist, return None otherwise - """ - if self.params["training_params"]["load_epoch"] in ("best", "last"): - for filename in os.listdir(self.paths["checkpoints"]): - if self.params["training_params"]["load_epoch"] in filename: - return torch.load(os.path.join(self.paths["checkpoints"], filename)) - return None - - def load_existing_model(self, checkpoint, strict=True): - """ - Load information and weights from previous training - """ - self.load_save_info(checkpoint) - self.latest_epoch = checkpoint["epoch"] - if "step" in checkpoint: - self.latest_step = checkpoint["step"] - self.best = checkpoint["best"] - if "scaler_state_dict" in checkpoint: - self.scaler.load_state_dict(checkpoint["scaler_state_dict"]) - # Load model weights from past training - for model_name in self.models.keys(): - self.models[model_name].load_state_dict(checkpoint["{}_state_dict".format(model_name)], strict=strict) - - def init_new_model(self): - """ - Initialize model - """ - # 
Specific weights initialization if exists - for model_name in self.models.keys(): - try: - self.models[model_name].init_weights() - except: - pass - - # Handle transfer learning instructions - if self.params["model_params"]["transfer_learning"]: - # Iterates over models - for model_name in self.params["model_params"]["transfer_learning"].keys(): - state_dict_name, path, learnable, strict = self.params["model_params"]["transfer_learning"][model_name] - # Loading pretrained weights file - checkpoint = torch.load(path) - try: - # Load pretrained weights for model - self.models[model_name].load_state_dict(checkpoint["{}_state_dict".format(state_dict_name)], strict=strict) - print("transfered weights for {}".format(state_dict_name), flush=True) - except RuntimeError as e: - print(e, flush=True) - # if error, try to load each parts of the model (useful if only few layers are different) - for key in checkpoint["{}_state_dict".format(state_dict_name)].keys(): - try: - # for pre-training of decision layer - if "end_conv" in key and "transfered_charset" in self.params["model_params"]: - self.adapt_decision_layer_to_old_charset(model_name, key, checkpoint, state_dict_name) - else: - self.models[model_name].load_state_dict( - {key: checkpoint["{}_state_dict".format(state_dict_name)][key]}, strict=False) - except RuntimeError as e: - ## exception when adding linebreak token from pretraining - print(e, flush=True) - # Set parameters no trainable - if not learnable: - self.set_model_learnable(self.models[model_name], False) - - def adapt_decision_layer_to_old_charset(self, model_name, key, checkpoint, state_dict_name): - """ - Transfer learning of the decision learning in case of close charsets between pre-training and training - """ - pretrained_chars = list() - weights = checkpoint["{}_state_dict".format(state_dict_name)][key] - new_size = list(weights.size()) - new_size[0] = len(self.dataset.charset) + self.params["model_params"]["additional_tokens"] - new_weights = torch.zeros(new_size, device=weights.device, dtype=weights.dtype) - old_charset = checkpoint["charset"] if "charset" in checkpoint else self.params["model_params"]["old_charset"] - if not "bias" in key: - kaiming_uniform_(new_weights, nonlinearity="relu") - for i, c in enumerate(self.dataset.charset): - if c in old_charset: - new_weights[i] = weights[old_charset.index(c)] - pretrained_chars.append(c) - if "transfered_charset_last_is_ctc_blank" in self.params["model_params"] and self.params["model_params"]["transfered_charset_last_is_ctc_blank"]: - new_weights[-1] = weights[-1] - pretrained_chars.append("<blank>") - checkpoint["{}_state_dict".format(state_dict_name)][key] = new_weights - self.models[model_name].load_state_dict({key: checkpoint["{}_state_dict".format(state_dict_name)][key]}, strict=False) - print("Pretrained chars for {} ({}): {}".format(key, len(pretrained_chars), pretrained_chars)) - - def load_optimizers(self, checkpoint, reset_optimizer=False): - """ - Load the optimizer of each model - """ - for model_name in self.models.keys(): - new_params = dict() - if checkpoint and "optimizer_named_params_{}".format(model_name) in checkpoint: - self.optimizers_named_params_by_group[model_name] = checkpoint["optimizer_named_params_{}".format(model_name)] - # for progressively growing models - for name, param in self.models[model_name].named_parameters(): - existing = False - for gr in self.optimizers_named_params_by_group[model_name]: - if name in gr: - gr[name] = param - existing = True - break - if not existing: - 
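# Parameters absent from the checkpointed param groups are gathered into
# new_params and attached to the optimizer as a fresh param group further
# down, so a progressively growing model keeps the optimizer state of the
# layers it already had.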
new_params.update({name: param}) - else: - self.optimizers_named_params_by_group[model_name] = [dict(), ] - self.optimizers_named_params_by_group[model_name][0].update(self.models[model_name].named_parameters()) - - # Instantiate optimizer - self.reset_optimizer(model_name) - - # Handle learning rate schedulers - if "lr_schedulers" in self.params["training_params"] and self.params["training_params"]["lr_schedulers"]: - key = "all" if "all" in self.params["training_params"]["lr_schedulers"] else model_name - if key in self.params["training_params"]["lr_schedulers"]: - self.lr_schedulers[model_name] = self.params["training_params"]["lr_schedulers"][key]["class"]\ - (self.optimizers[model_name], **self.params["training_params"]["lr_schedulers"][key]["args"]) - - # Load optimizer state from past training - if checkpoint and not reset_optimizer: - self.optimizers[model_name].load_state_dict(checkpoint["optimizer_{}_state_dict".format(model_name)]) - # Load optimizer scheduler config from past training if used - if "lr_schedulers" in self.params["training_params"] and self.params["training_params"]["lr_schedulers"] \ - and "lr_scheduler_{}_state_dict".format(model_name) in checkpoint.keys(): - self.lr_schedulers[model_name].load_state_dict(checkpoint["lr_scheduler_{}_state_dict".format(model_name)]) - - # for progressively growing models, keeping learning rate - if checkpoint and new_params: - self.optimizers_named_params_by_group[model_name].append(new_params) - self.optimizers[model_name].add_param_group({"params": list(new_params.values())}) - - @staticmethod - def set_model_learnable(model, learnable=True): - for p in list(model.parameters()): - p.requires_grad = learnable - - def save_model(self, epoch, name, keep_weights=False): - """ - Save model weights and training info for curriculum learning or learning rate for instance - """ - if not self.is_master: - return - to_del = [] - for filename in os.listdir(self.paths["checkpoints"]): - if name in filename: - to_del.append(os.path.join(self.paths["checkpoints"], filename)) - path = os.path.join(self.paths["checkpoints"], "{}_{}.pt".format(name, epoch)) - content = { - 'optimizers_named_params': self.optimizers_named_params_by_group, - 'epoch': epoch, - 'step': self.latest_step, - "scaler_state_dict": self.scaler.state_dict(), - 'best': self.best, - "charset": self.dataset.charset - } - for model_name in self.optimizers: - content['optimizer_{}_state_dict'.format(model_name)] = self.optimizers[model_name].state_dict() - for model_name in self.lr_schedulers: - content["lr_scheduler_{}_state_dict".format(model_name)] = self.lr_schedulers[model_name].state_dict() - content = self.add_save_info(content) - for model_name in self.models.keys(): - content["{}_state_dict".format(model_name)] = self.models[model_name].state_dict() - torch.save(content, path) - if not keep_weights: - for path_to_del in to_del: - if path_to_del != path: - os.remove(path_to_del) - - def reset_optimizers(self): - """ - Reset learning rate of all optimizers - """ - for model_name in self.models.keys(): - self.reset_optimizer(model_name) - - def reset_optimizer(self, model_name): - """ - Reset optimizer learning rate for given model - """ - params = list(self.optimizers_named_params_by_group[model_name][0].values()) - key = "all" if "all" in self.params["training_params"]["optimizers"] else model_name - self.optimizers[model_name] = self.params["training_params"]["optimizers"][key]["class"](params, **self.params["training_params"]["optimizers"][key]["args"]) - for i in 
range(1, len(self.optimizers_named_params_by_group[model_name])): - self.optimizers[model_name].add_param_group({"params": list(self.optimizers_named_params_by_group[model_name][i].values())}) - - def save_params(self): - """ - Output text file containing a summary of all hyperparameters chosen for the training - """ - def compute_nb_params(module): - return sum([np.prod(p.size()) for p in list(module.parameters())]) - - def class_to_str_dict(my_dict): - for key in my_dict.keys(): - if callable(my_dict[key]): - my_dict[key] = my_dict[key].__name__ - elif isinstance(my_dict[key], np.ndarray): - my_dict[key] = my_dict[key].tolist() - elif isinstance(my_dict[key], dict): - my_dict[key] = class_to_str_dict(my_dict[key]) - return my_dict - - path = os.path.join(self.paths["results"], "params") - if os.path.isfile(path): - return - params = copy.deepcopy(self.params) - params = class_to_str_dict(params) - params["date"] = date.today().strftime("%d/%m/%Y") - total_params = 0 - for model_name in self.models.keys(): - current_params = compute_nb_params(self.models[model_name]) - params["model_params"]["models"][model_name] = [params["model_params"]["models"][model_name], "{:,}".format(current_params)] - total_params += current_params - params["model_params"]["total_params"] = "{:,}".format(total_params) - - params["hardware"] = dict() - if self.device != "cpu": - for i in range(self.params["training_params"]["nb_gpu"]): - params["hardware"][str(i)] = "{} {}".format(torch.cuda.get_device_name(i), torch.cuda.get_device_properties(i)) - else: - params["hardware"]["0"] = "CPU" - params["software"] = { - "python_version": sys.version, - "pytorch_version": torch.__version__, - "cuda_version": torch.version.cuda, - "cudnn_version": torch.backends.cudnn.version(), - } - with open(path, 'w') as f: - json.dump(params, f, indent=4) - - def backward_loss(self, loss, retain_graph=False): - self.scaler.scale(loss).backward(retain_graph=retain_graph) - - def step_optimizers(self, increment_step=True, names=None): - for model_name in self.optimizers: - if names and model_name not in names: - continue - if "gradient_clipping" in self.params["training_params"] and model_name in self.params["training_params"]["gradient_clipping"]["models"]: - self.scaler.unscale_(self.optimizers[model_name]) - torch.nn.utils.clip_grad_norm_(self.models[model_name].parameters(), self.params["training_params"]["gradient_clipping"]["max"]) - self.scaler.step(self.optimizers[model_name]) - self.scaler.update() - self.latest_step += 1 - - def zero_optimizers(self, set_to_none=True): - for model_name in self.optimizers: - self.zero_optimizer(model_name, set_to_none) - - def zero_optimizer(self, model_name, set_to_none=True): - self.optimizers[model_name].zero_grad(set_to_none=set_to_none) - - def train(self): - """ - Main training loop - """ - # init tensorboard file and output param summary file - if self.is_master: - self.writer = SummaryWriter(self.paths["results"]) - self.save_params() - # init variables - self.begin_time = time() - focus_metric_name = self.params["training_params"]["focus_metric"] - nb_epochs = self.params["training_params"]["max_nb_epochs"] - interval_save_weights = self.params["training_params"]["interval_save_weights"] - metric_names = self.params["training_params"]["train_metrics"] - - display_values = None - # init curriculum learning - if "curriculum_learning" in self.params["training_params"].keys() and self.params["training_params"]["curriculum_learning"]: - self.init_curriculum() - # perform epochs - for 
num_epoch in range(self.latest_epoch+1, nb_epochs): - self.dataset.train_dataset.training_info = { - "epoch": self.latest_epoch, - "step": self.latest_step - } - self.phase = "train" - # Check maximum training time stop condition - if self.params["training_params"]["max_training_time"] and time() - self.begin_time > self.params["training_params"]["max_training_time"]: - break - # set models trainable - for model_name in self.models.keys(): - self.models[model_name].train() - self.latest_epoch = num_epoch - if self.dataset.train_dataset.curriculum_config: - self.dataset.train_dataset.curriculum_config["epoch"] = self.latest_epoch - # init epoch metrics values - self.metric_manager["train"] = MetricManager(metric_names=metric_names, dataset_name=self.dataset_name) - - with tqdm(total=len(self.dataset.train_loader.dataset)) as pbar: - pbar.set_description("EPOCH {}/{}".format(num_epoch, nb_epochs)) - # iterates over mini-batch data - for ind_batch, batch_data in enumerate(self.dataset.train_loader): - self.latest_batch = ind_batch + 1 - self.total_batch += 1 - # train on batch data and compute metrics - batch_values = self.train_batch(batch_data, metric_names) - batch_metrics = self.metric_manager["train"].compute_metrics(batch_values, metric_names) - batch_metrics["names"] = batch_data["names"] - batch_metrics["ids"] = batch_data["ids"] - # Merge metrics if Distributed Data Parallel is used - if self.params["training_params"]["use_ddp"]: - batch_metrics = self.merge_ddp_metrics(batch_metrics) - # Update learning rate via scheduler if one is used - if self.params["training_params"]["lr_schedulers"]: - for model_name in self.models: - key = "all" if "all" in self.params["training_params"]["lr_schedulers"] else model_name - if model_name in self.lr_schedulers and ind_batch % self.params["training_params"]["lr_schedulers"][key]["step_interval"] == 0: - self.lr_schedulers[model_name].step(len(batch_metrics["names"])) - if "lr" in metric_names: - self.writer.add_scalar("lr_{}".format(model_name), self.lr_schedulers[model_name].lr, self.lr_schedulers[model_name].step_num) - # Update dropout scheduler if used - if self.dropout_scheduler: - self.dropout_scheduler.step(len(batch_metrics["names"])) - self.dropout_scheduler.update_dropout_rate() - - # Add batch metrics values to epoch metrics values - self.metric_manager["train"].update_metrics(batch_metrics) - display_values = self.metric_manager["train"].get_display_values() - pbar.set_postfix(values=str(display_values)) - pbar.update(len(batch_data["names"])) - - # log metrics in tensorboard file - if self.is_master: - for key in display_values.keys(): - self.writer.add_scalar('{}_{}'.format(self.params["dataset_params"]["train"]["name"], key), display_values[key], num_epoch) - self.latest_train_metrics = display_values - - # evaluate and compute metrics for valid sets - if self.params["training_params"]["eval_on_valid"] and num_epoch % self.params["training_params"]["eval_on_valid_interval"] == 0: - for valid_set_name in self.dataset.valid_loaders.keys(): - # evaluate set and compute metrics - eval_values = self.evaluate(valid_set_name) - self.latest_valid_metrics = eval_values - # log valid metrics in tensorboard file - if self.is_master: - for key in eval_values.keys(): - self.writer.add_scalar('{}_{}'.format(valid_set_name, key), eval_values[key], num_epoch) - if valid_set_name == self.params["training_params"]["set_name_focus_metric"] and (self.best is None or \ - (eval_values[focus_metric_name] <= self.best and 
self.params["training_params"]["expected_metric_value"] == "low") or\ - (eval_values[focus_metric_name] >= self.best and self.params["training_params"]["expected_metric_value"] == "high")): - self.save_model(epoch=num_epoch, name="best") - self.best = eval_values[focus_metric_name] - - # Handle curriculum learning update - if self.dataset.train_dataset.curriculum_config: - self.check_and_update_curriculum() - - if "curriculum_model" in self.params["model_params"] and self.params["model_params"]["curriculum_model"]: - self.update_curriculum_model() - - # save model weights - if self.is_master: - self.save_model(epoch=num_epoch, name="last") - if interval_save_weights and num_epoch % interval_save_weights == 0: - self.save_model(epoch=num_epoch, name="weigths", keep_weights=True) - self.writer.flush() - - def evaluate(self, set_name, **kwargs): - """ - Main loop for validation - """ - self.phase = "eval" - loader = self.dataset.valid_loaders[set_name] - # Set models in eval mode - for model_name in self.models.keys(): - self.models[model_name].eval() - metric_names = self.params["training_params"]["eval_metrics"] - display_values = None - - # initialize epoch metrics - self.metric_manager[set_name] = MetricManager(metric_names, dataset_name=self.dataset_name) - with tqdm(total=len(loader.dataset)) as pbar: - pbar.set_description("Evaluation E{}".format(self.latest_epoch)) - with torch.no_grad(): - # iterate over batch data - for ind_batch, batch_data in enumerate(loader): - self.latest_batch = ind_batch + 1 - # eval batch data and compute metrics - batch_values = self.evaluate_batch(batch_data, metric_names) - batch_metrics = self.metric_manager[set_name].compute_metrics(batch_values, metric_names) - batch_metrics["names"] = batch_data["names"] - batch_metrics["ids"] = batch_data["ids"] - # merge metrics values if Distributed Data Parallel is used - if self.params["training_params"]["use_ddp"]: - batch_metrics = self.merge_ddp_metrics(batch_metrics) - - # add batch metrics to epoch metrics - self.metric_manager[set_name].update_metrics(batch_metrics) - display_values = self.metric_manager[set_name].get_display_values() - - pbar.set_postfix(values=str(display_values)) - pbar.update(len(batch_data["names"])) - if "cer_by_nb_cols" in metric_names: - self.log_cer_by_nb_cols(set_name) - return display_values - - def predict(self, custom_name, sets_list, metric_names, output=False): - """ - Main loop for evaluation - """ - self.phase = "predict" - metric_names = metric_names.copy() - self.dataset.generate_test_loader(custom_name, sets_list) - loader = self.dataset.test_loaders[custom_name] - # Set models in eval mode - for model_name in self.models.keys(): - self.models[model_name].eval() - - # initialize epoch metrics - self.metric_manager[custom_name] = MetricManager(metric_names, self.dataset_name) - - with tqdm(total=len(loader.dataset)) as pbar: - pbar.set_description("Prediction") - with torch.no_grad(): - for ind_batch, batch_data in enumerate(loader): - # iterates over batch data - self.latest_batch = ind_batch + 1 - # eval batch data and compute metrics - batch_values = self.evaluate_batch(batch_data, metric_names) - batch_metrics = self.metric_manager[custom_name].compute_metrics(batch_values, metric_names) - batch_metrics["names"] = batch_data["names"] - batch_metrics["ids"] = batch_data["ids"] - # merge batch metrics if Distributed Data Parallel is used - if self.params["training_params"]["use_ddp"]: - batch_metrics = self.merge_ddp_metrics(batch_metrics) - - # add batch metrics to 
epoch metrics - self.metric_manager[custom_name].update_metrics(batch_metrics) - display_values = self.metric_manager[custom_name].get_display_values() - - pbar.set_postfix(values=str(display_values)) - pbar.update(len(batch_data["names"])) - - self.dataset.remove_test_dataset(custom_name) - # output metrics values if requested - if output: - if "pred" in metric_names: - self.output_pred(custom_name) - metrics = self.metric_manager[custom_name].get_display_values(output=True) - path = os.path.join(self.paths["results"], "predict_{}_{}.txt".format(custom_name, self.latest_epoch)) - with open(path, "w") as f: - for metric_name in metrics.keys(): - f.write("{}: {}\n".format(metric_name, metrics[metric_name])) - - def output_pred(self, name): - path = os.path.join(self.paths["results"], "pred_{}_{}.txt".format(name, self.latest_epoch)) - pred = "\n".join(self.metric_manager[name].get("pred")) - with open(path, "w") as f: - f.write(pred) - - def launch_ddp(self): - """ - Initialize Distributed Data Parallel system - """ - mp.set_start_method('fork', force=True) - os.environ['MASTER_ADDR'] = self.ddp_config["address"] - os.environ['MASTER_PORT'] = str(self.ddp_config["port"]) - dist.init_process_group(self.ddp_config["backend"], rank=self.ddp_config["rank"], world_size=self.params["training_params"]["nb_gpu"]) - torch.cuda.set_device(self.ddp_config["rank"]) - random.seed(self.manual_seed) - np.random.seed(self.manual_seed) - torch.manual_seed(self.manual_seed) - torch.cuda.manual_seed(self.manual_seed) - - def merge_ddp_metrics(self, metrics): - """ - Merge metrics when Distributed Data Parallel is used - """ - for metric_name in metrics.keys(): - if metric_name in ["edit_words", "nb_words", "edit_chars", "nb_chars", "edit_chars_force_len", - "edit_chars_curr", "nb_chars_curr", "ids"]: - metrics[metric_name] = self.cat_ddp_metric(metrics[metric_name]) - elif metric_name in ["nb_samples", "loss", "loss_ce", "loss_ctc", "loss_ce_end"]: - metrics[metric_name] = self.sum_ddp_metric(metrics[metric_name], average=False) - return metrics - - def sum_ddp_metric(self, metric, average=False): - """ - Sum metrics for Distributed Data Parallel - """ - sum = torch.tensor(metric[0]).to(self.device) - dist.all_reduce(sum, op=dist.ReduceOp.SUM) - if average: - sum.true_divide(dist.get_world_size()) - return [sum.item(), ] - - def cat_ddp_metric(self, metric): - """ - Concatenate metrics for Distributed Data Parallel - """ - tensor = torch.tensor(metric).unsqueeze(0).to(self.device) - res = [torch.zeros(tensor.size()).long().to(self.device) for _ in range(dist.get_world_size())] - dist.all_gather(res, tensor) - return list(torch.cat(res, dim=0).flatten().cpu().numpy()) - - @staticmethod - def cleanup(): - dist.destroy_process_group() - - def train_batch(self, batch_data, metric_names): - raise NotImplementedError - - def evaluate_batch(self, batch_data, metric_names): - raise NotImplementedError - - def init_curriculum(self): - raise NotImplementedError - - def update_curriculum(self): - raise NotImplementedError - - def add_checkpoint_info(self, load_mode="last", **kwargs): - for filename in os.listdir(self.paths["checkpoints"]): - if load_mode in filename: - checkpoint_path = os.path.join(self.paths["checkpoints"], filename) - checkpoint = torch.load(checkpoint_path) - for key in kwargs.keys(): - checkpoint[key] = kwargs[key] - torch.save(checkpoint, checkpoint_path) - return - self.save_model(self.latest_epoch, "last") - - def load_save_info(self, info_dict): - """ - Load curriculum info from saved model 
info - """ - if "curriculum_config" in info_dict.keys(): - self.dataset.train_dataset.curriculum_config = info_dict["curriculum_config"] - - def add_save_info(self, info_dict): - """ - Add curriculum info to model info to be saved - """ - info_dict["curriculum_config"] = self.dataset.train_dataset.curriculum_config - return info_dict \ No newline at end of file diff --git a/basic/metric_manager.py b/basic/metric_manager.py deleted file mode 100644 index 6c8576863a2d26a85c2d5dbc4321323dedf870a0..0000000000000000000000000000000000000000 --- a/basic/metric_manager.py +++ /dev/null @@ -1,538 +0,0 @@ - -from Datasets.dataset_formatters.rimes_formatter import SEM_MATCHING_TOKENS as RIMES_MATCHING_TOKENS -from Datasets.dataset_formatters.read2016_formatter import SEM_MATCHING_TOKENS as READ_MATCHING_TOKENS -from Datasets.dataset_formatters.simara_formatter import SEM_MATCHING_TOKENS as SIMARA_MATCHING_TOKENS -import re -import networkx as nx -import editdistance -import numpy as np -from dan.post_processing import PostProcessingModuleREAD, PostProcessingModuleRIMES, PostProcessingModuleSIMARA - - -class MetricManager: - - def __init__(self, metric_names, dataset_name): - self.dataset_name = dataset_name - if "READ" in dataset_name and "page" in dataset_name: - self.post_processing_module = PostProcessingModuleREAD - self.matching_tokens = READ_MATCHING_TOKENS - self.edit_and_num_edge_nodes = edit_and_num_items_for_ged_from_str_read - elif "RIMES" in dataset_name and "page" in dataset_name: - self.post_processing_module = PostProcessingModuleRIMES - self.matching_tokens = RIMES_MATCHING_TOKENS - self.edit_and_num_edge_nodes = edit_and_num_items_for_ged_from_str_rimes - elif "simara" in dataset_name and "page" in dataset_name: - self.post_processing_module = PostProcessingModuleSIMARA - self.matching_tokens = SIMARA_MATCHING_TOKENS - self.edit_and_num_edge_nodes = edit_and_num_items_for_ged_from_str_simara - else: - self.matching_tokens = dict() - - self.layout_tokens = "".join(list(self.matching_tokens.keys()) + list(self.matching_tokens.values())) - if len(self.layout_tokens) == 0: - self.layout_tokens = None - self.metric_names = metric_names - self.epoch_metrics = None - - self.linked_metrics = { - "cer": ["edit_chars", "nb_chars"], - "wer": ["edit_words", "nb_words"], - "loer": ["edit_graph", "nb_nodes_and_edges", "nb_pp_op_layout", "nb_gt_layout_token"], - "precision": ["precision", "weights"], - "map_cer_per_class": ["map_cer", ], - "layout_precision_per_class_per_threshold": ["map_cer", ], - } - - self.init_metrics() - - def init_metrics(self): - """ - Initialization of the metrics specified in metrics_name - """ - self.epoch_metrics = { - "nb_samples": list(), - "names": list(), - "ids": list(), - } - - for metric_name in self.metric_names: - if metric_name in self.linked_metrics: - for linked_metric_name in self.linked_metrics[metric_name]: - if linked_metric_name not in self.epoch_metrics.keys(): - self.epoch_metrics[linked_metric_name] = list() - else: - self.epoch_metrics[metric_name] = list() - - def update_metrics(self, batch_metrics): - """ - Add batch metrics to the metrics - """ - for key in batch_metrics.keys(): - if key in self.epoch_metrics: - self.epoch_metrics[key] += batch_metrics[key] - - def get_display_values(self, output=False): - """ - format metrics values for shell display purposes - """ - metric_names = self.metric_names.copy() - if output: - metric_names.extend(["nb_samples"]) - display_values = dict() - for metric_name in metric_names: - value = None - if output: - 
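# Output mode (final prediction/evaluation files) adds aggregate values such
# as sample counts, total weights and per-class AP breakdowns on top of the
# regular display metrics.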
if metric_name in ["nb_samples", "weights"]: - value = np.sum(self.epoch_metrics[metric_name]) - elif metric_name in ["time", ]: - total_time = np.sum(self.epoch_metrics[metric_name]) - sample_time = total_time / np.sum(self.epoch_metrics["nb_samples"]) - display_values["sample_time"] = round(sample_time, 4) - value = total_time - elif metric_name == "loer": - display_values["pper"] = round(np.sum(self.epoch_metrics["nb_pp_op_layout"]) / np.sum(self.epoch_metrics["nb_gt_layout_token"]), 4) - elif metric_name == "map_cer_per_class": - value = compute_global_mAP_per_class(self.epoch_metrics["map_cer"]) - for key in value.keys(): - display_values["map_cer_" + key] = round(value[key], 4) - continue - elif metric_name == "layout_precision_per_class_per_threshold": - value = compute_global_precision_per_class_per_threshold(self.epoch_metrics["map_cer"]) - for key_class in value.keys(): - for threshold in value[key_class].keys(): - display_values["map_cer_{}_{}".format(key_class, threshold)] = round( - value[key_class][threshold], 4) - continue - if metric_name == "cer": - value = np.sum(self.epoch_metrics["edit_chars"]) / np.sum(self.epoch_metrics["nb_chars"]) - if output: - display_values["nb_chars"] = np.sum(self.epoch_metrics["nb_chars"]) - elif metric_name == "wer": - value = np.sum(self.epoch_metrics["edit_words"]) / np.sum(self.epoch_metrics["nb_words"]) - if output: - display_values["nb_words"] = np.sum(self.epoch_metrics["nb_words"]) - elif metric_name in ["loss", "loss_ctc", "loss_ce", "syn_max_lines"]: - value = np.average(self.epoch_metrics[metric_name], weights=np.array(self.epoch_metrics["nb_samples"])) - elif metric_name == "map_cer": - value = compute_global_mAP(self.epoch_metrics[metric_name]) - elif metric_name == "loer": - value = np.sum(self.epoch_metrics["edit_graph"]) / np.sum(self.epoch_metrics["nb_nodes_and_edges"]) - elif value is None: - continue - - display_values[metric_name] = round(value, 4) - return display_values - - def compute_metrics(self, values, metric_names): - metrics = { - "nb_samples": [values["nb_samples"], ], - } - for v in ["weights", "time"]: - if v in values: - metrics[v] = [values[v]] - for metric_name in metric_names: - if metric_name == "cer": - metrics["edit_chars"] = [edit_cer_from_string(u, v, self.layout_tokens) for u, v in zip(values["str_y"], values["str_x"])] - metrics["nb_chars"] = [nb_chars_cer_from_string(gt, self.layout_tokens) for gt in values["str_y"]] - elif metric_name == "wer": - split_gt = [format_string_for_wer(gt, self.layout_tokens) for gt in values["str_y"]] - split_pred = [format_string_for_wer(pred, self.layout_tokens) for pred in values["str_x"]] - metrics["edit_words"] = [edit_wer_from_formatted_split_text(gt, pred) for (gt, pred) in zip(split_gt, split_pred)] - metrics["nb_words"] = [len(gt) for gt in split_gt] - elif metric_name in ["loss_ctc", "loss_ce", "loss", "syn_max_lines", ]: - metrics[metric_name] = [values[metric_name], ] - elif metric_name == "map_cer": - pp_pred = list() - pp_score = list() - for pred, score in zip(values["str_x"], values["confidence_score"]): - pred_score = self.post_processing_module().post_process(pred, score) - pp_pred.append(pred_score[0]) - pp_score.append(pred_score[1]) - metrics[metric_name] = [compute_layout_mAP_per_class(y, x, conf, self.matching_tokens) for x, conf, y in zip(pp_pred, pp_score, values["str_y"])] - elif metric_name == "loer": - pp_pred = list() - metrics["nb_pp_op_layout"] = list() - for pred in values["str_x"]: - pp_module = self.post_processing_module() - 
pp_pred.append(pp_module.post_process(pred)) - metrics["nb_pp_op_layout"].append(pp_module.num_op) - metrics["nb_gt_layout_token"] = [len(keep_only_tokens(str_x, self.layout_tokens)) for str_x in values["str_x"]] - edit_and_num_items = [self.edit_and_num_edge_nodes(y, x) for x, y in zip(pp_pred, values["str_y"])] - metrics["edit_graph"], metrics["nb_nodes_and_edges"] = [ei[0] for ei in edit_and_num_items], [ei[1] for ei in edit_and_num_items] - return metrics - - def get(self, name): - return self.epoch_metrics[name] - - -def keep_only_tokens(str, tokens): - """ - Remove all but layout tokens from string - """ - return re.sub('([^' + tokens + '])', '', str) - - -def keep_all_but_tokens(str, tokens): - """ - Remove all layout tokens from string - """ - return re.sub('([' + tokens + '])', '', str) - - -def edit_cer_from_string(gt, pred, layout_tokens=None): - """ - Format and compute edit distance between two strings at character level - """ - gt = format_string_for_cer(gt, layout_tokens) - pred = format_string_for_cer(pred, layout_tokens) - return editdistance.eval(gt, pred) - - -def nb_chars_cer_from_string(gt, layout_tokens=None): - """ - Compute length after formatting of ground truth string - """ - return len(format_string_for_cer(gt, layout_tokens)) - - -def edit_wer_from_string(gt, pred, layout_tokens=None): - """ - Format and compute edit distance between two strings at word level - """ - split_gt = format_string_for_wer(gt, layout_tokens) - split_pred = format_string_for_wer(pred, layout_tokens) - return edit_wer_from_formatted_split_text(split_gt, split_pred) - - -def format_string_for_wer(str, layout_tokens): - """ - Format string for WER computation: remove layout tokens, treat punctuation as word, replace line break by space - """ - str = re.sub('([\[\]{}/\\()\"\'&+*=<>?.;:,!\-—_€#%°])', r' \1 ', str) # punctuation processed as word - if layout_tokens is not None: - str = keep_all_but_tokens(str, layout_tokens) # remove layout tokens from metric - str = re.sub('([ \n])+', " ", str).strip() # keep only one space character - return str.split(" ") - - -def format_string_for_cer(str, layout_tokens): - """ - Format string for CER computation: remove layout tokens and extra spaces - """ - if layout_tokens is not None: - str = keep_all_but_tokens(str, layout_tokens) # remove layout tokens from metric - str = re.sub('([\n])+', "\n", str) # remove consecutive line breaks - str = re.sub('([ ])+', " ", str).strip() # remove consecutive spaces - return str - - -def edit_wer_from_formatted_split_text(gt, pred): - """ - Compute edit distance at word level from formatted string as list - """ - return editdistance.eval(gt, pred) - - -def extract_by_tokens(input_str, begin_token, end_token, associated_score=None, order_by_score=False): - """ - Extract list of text regions by begin and end tokens - Order the list by confidence score - """ - if order_by_score: - assert associated_score is not None - res = list() - for match in re.finditer("{}[^{}]*{}".format(begin_token, end_token, end_token), input_str): - begin, end = match.regs[0] - if order_by_score: - res.append({ - "confidence": np.mean([associated_score[begin], associated_score[end-1]]), - "content": input_str[begin+1:end-1] - }) - else: - res.append(input_str[begin+1:end-1]) - if order_by_score: - res = sorted(res, key=lambda x: x["confidence"], reverse=True) - res = [r["content"] for r in res] - return res - - -def compute_layout_precision_per_threshold(gt, pred, score, begin_token, end_token, layout_tokens, return_weight=True): - """ - 
Compute average precision of a given class for CER thresholds from 5% to 50%, with a step of 5%
-    """
-    pred_list = extract_by_tokens(pred, begin_token, end_token, associated_score=score, order_by_score=True)
-    gt_list = extract_by_tokens(gt, begin_token, end_token)
-    pred_list = [keep_all_but_tokens(p, layout_tokens) for p in pred_list]
-    gt_list = [keep_all_but_tokens(gt, layout_tokens) for gt in gt_list]
-    precision_per_threshold = [compute_layout_AP_for_given_threshold(gt_list, pred_list, threshold/100) for threshold in range(5, 51, 5)]
-    if return_weight:
-        return precision_per_threshold, len(gt_list)
-    return precision_per_threshold
-
-
-def compute_layout_AP_for_given_threshold(gt_list, pred_list, threshold):
-    """
-    Compute average precision of a given class for a given CER threshold
-    """
-    remaining_gt_list = gt_list.copy()
-    num_true = len(gt_list)
-    correct = np.zeros((len(pred_list)), dtype=bool)
-    for i, pred in enumerate(pred_list):
-        if len(remaining_gt_list) == 0:
-            break
-        cer_with_gt = [edit_cer_from_string(gt, pred)/nb_chars_cer_from_string(gt) for gt in remaining_gt_list]
-        cer, ind = np.min(cer_with_gt), np.argmin(cer_with_gt)
-        if cer <= threshold:
-            correct[i] = True
-            del remaining_gt_list[ind]
-    precision = np.cumsum(correct, dtype=np.int64) / np.arange(1, len(pred_list)+1)
-    recall = np.cumsum(correct, dtype=np.int64) / num_true
-    max_precision_from_recall = np.maximum.accumulate(precision[::-1])[::-1]
-    recall_diff = (recall - np.concatenate([np.array([0, ]), recall[:-1]]))
-    P = np.sum(recall_diff * max_precision_from_recall)
-    return P
-
-
-def compute_layout_mAP_per_class(gt, pred, score, tokens):
-    """
-    Compute the mAP_cer for each class for a given sample
-    """
-    layout_tokens = "".join(list(tokens.keys()))
-    AP_per_class = dict()
-    for token in tokens.keys():
-        if token in gt:
-            AP_per_class[token] = compute_layout_precision_per_threshold(gt, pred, score, token, tokens[token], layout_tokens=layout_tokens)
-    return AP_per_class
-
-
-def compute_global_mAP(list_AP_per_class):
-    """
-    Compute the global mAP_cer for several samples
-    """
-    weights_per_doc = list()
-    mAP_per_doc = list()
-    for doc_AP_per_class in list_AP_per_class:
-        APs = np.array([np.mean(doc_AP_per_class[key][0]) for key in doc_AP_per_class.keys()])
-        weights = np.array([doc_AP_per_class[key][1] for key in doc_AP_per_class.keys()])
-        if np.sum(weights) == 0:
-            mAP_per_doc.append(0)
-        else:
-            mAP_per_doc.append(np.average(APs, weights=weights))
-        weights_per_doc.append(np.sum(weights))
-    if np.sum(weights_per_doc) == 0:
-        return 0
-    return np.average(mAP_per_doc, weights=weights_per_doc)
-
-
-def compute_global_mAP_per_class(list_AP_per_class):
-    """
-    Compute the mAP_cer per class for several samples
-    """
-    mAP_per_class = dict()
-    for doc_AP_per_class in list_AP_per_class:
-        for key in doc_AP_per_class.keys():
-            if key not in mAP_per_class:
-                mAP_per_class[key] = {
-                    "AP": list(),
-                    "weights": list()
-                }
-            mAP_per_class[key]["AP"].append(np.mean(doc_AP_per_class[key][0]))
-            mAP_per_class[key]["weights"].append(doc_AP_per_class[key][1])
-    for key in mAP_per_class.keys():
-        mAP_per_class[key] = np.average(mAP_per_class[key]["AP"], weights=mAP_per_class[key]["weights"])
-    return mAP_per_class
-
-
-def compute_global_precision_per_class_per_threshold(list_AP_per_class):
-    """
-    Compute the mAP_cer per class and per threshold for several samples
-    """
-    mAP_per_class = dict()
-    for doc_AP_per_class in list_AP_per_class:
-        for key in doc_AP_per_class.keys():
-            if key not in mAP_per_class:
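# First occurrence of this class: build one precision/weight accumulator
# per CER threshold (5% to 50%, step 5%).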
-                mAP_per_class[key] = dict()
-                for threshold in range(5, 51, 5):
-                    mAP_per_class[key][threshold] = {
-                        "precision": list(),
-                        "weights": list()
-                    }
-            for i, threshold in enumerate(range(5, 51, 5)):
-                mAP_per_class[key][threshold]["precision"].append(np.mean(doc_AP_per_class[key][0][i]))
-                mAP_per_class[key][threshold]["weights"].append(doc_AP_per_class[key][1])
-    for key_class in mAP_per_class.keys():
-        for threshold in mAP_per_class[key_class]:
-            mAP_per_class[key_class][threshold] = np.average(mAP_per_class[key_class][threshold]["precision"], weights=mAP_per_class[key_class][threshold]["weights"])
-    return mAP_per_class
-
-
-def str_to_graph_read(str):
-    """
-    Compute graph from string of layout tokens for the READ 2016 dataset at single-page and double-page levels
-    """
-    begin_layout_tokens = "".join(list(READ_MATCHING_TOKENS.keys()))
-    layout_token_sequence = keep_only_tokens(str, begin_layout_tokens)
-    g = nx.DiGraph()
-    g.add_node("D", type="document", level=4, page=0)
-    num = {
-        "ⓟ": 0,
-        "ⓝ": 0,
-        "ⓑ": 0,
-        "ⓐ": 0,
-        "ⓢ": 0
-    }
-    previous_top_level_node = None
-    previous_middle_level_node = None
-    previous_low_level_node = None
-    for ind, c in enumerate(layout_token_sequence):
-        num[c] += 1
-        if c == "ⓟ":
-            node_name = "P_{}".format(num[c])
-            g.add_node(node_name, type="page", level=3, page=num["ⓟ"])
-            g.add_edge("D", node_name)
-            if previous_top_level_node:
-                g.add_edge(previous_top_level_node, node_name)
-            previous_top_level_node = node_name
-            previous_middle_level_node = None
-            previous_low_level_node = None
-        if c in "ⓝⓢ":
-            node_name = "{}_{}".format("N" if c == "ⓝ" else "S", num[c])
-            g.add_node(node_name, type="number" if c == "ⓝ" else "section", level=2, page=num["ⓟ"])
-            g.add_edge(previous_top_level_node, node_name)
-            if previous_middle_level_node:
-                g.add_edge(previous_middle_level_node, node_name)
-            previous_middle_level_node = node_name
-            previous_low_level_node = None
-        if c in "ⓐⓑ":
-            node_name = "{}_{}".format("A" if c == "ⓐ" else "B", num[c])
-            g.add_node(node_name, type="annotation" if c == "ⓐ" else "body", level=1, page=num["ⓟ"])
-            g.add_edge(previous_middle_level_node, node_name)
-            if previous_low_level_node:
-                g.add_edge(previous_low_level_node, node_name)
-            previous_low_level_node = node_name
-    return g
-
-
-def str_to_graph_rimes(str):
-    """
-    Compute graph from string of layout tokens for the RIMES dataset at page level
-    """
-    begin_layout_tokens = "".join(list(RIMES_MATCHING_TOKENS.keys()))
-    layout_token_sequence = keep_only_tokens(str, begin_layout_tokens)
-    g = nx.DiGraph()
-    g.add_node("D", type="document", level=2, page=0)
-    token_name_dict = {
-        "ⓑ": "B",
-        "ⓞ": "O",
-        "ⓡ": "R",
-        "ⓢ": "S",
-        "ⓦ": "W",
-        "ⓨ": "Y",
-        "ⓟ": "P"
-    }
-    num = dict()
-    previous_node = None
-    for token in begin_layout_tokens:
-        num[token] = 0
-    for ind, c in enumerate(layout_token_sequence):
-        num[c] += 1
-        node_name = "{}_{}".format(token_name_dict[c], num[c])
-        g.add_node(node_name, type=token_name_dict[c], level=1, page=0)
-        g.add_edge("D", node_name)
-        if previous_node:
-            g.add_edge(previous_node, node_name)
-        previous_node = node_name
-    return g
-
-
-def str_to_graph_simara(str):
-    """
-    Compute graph from string of layout tokens for the SIMARA dataset at page level
-    """
-    begin_layout_tokens = "".join(list(SIMARA_MATCHING_TOKENS.keys()))
-    layout_token_sequence = keep_only_tokens(str, begin_layout_tokens)
-    g = nx.DiGraph()
-    g.add_node("D", type="document", level=2, page=0)
-    token_name_dict = {
-        "ⓘ": "I",
-        "ⓓ": "D",
-        "ⓢ": "S",
-        "ⓒ": "C",
-        "ⓟ": "P",
-        "ⓐ": "A"
-    }
-    num = dict()
-    previous_node = None
-    for token in begin_layout_tokens:
-        num[token] = 0
-    for ind, c in enumerate(layout_token_sequence):
-        num[c] += 1
-        node_name = "{}_{}".format(token_name_dict[c], num[c])
-        g.add_node(node_name, type=token_name_dict[c], level=1, page=0)
-        g.add_edge("D", node_name)
-        if previous_node:
-            g.add_edge(previous_node, node_name)
-        previous_node = node_name
-    return g
-
-
-def graph_edit_distance_by_page_read(g1, g2):
-    """
-    Compute graph edit distance page by page for the READ 2016 dataset
-    """
-    num_pages_g1 = len([n for n in g1.nodes().items() if n[1]["level"] == 3])
-    num_pages_g2 = len([n for n in g2.nodes().items() if n[1]["level"] == 3])
-    page_graphs_1 = [g1.subgraph([n[0] for n in g1.nodes().items() if n[1]["page"] == num_page]) for num_page in range(1, num_pages_g1+1)]
-    page_graphs_2 = [g2.subgraph([n[0] for n in g2.nodes().items() if n[1]["page"] == num_page]) for num_page in range(1, num_pages_g2+1)]
-    edit = 0
-    for i in range(max(len(page_graphs_1), len(page_graphs_2))):
-        page_1 = page_graphs_1[i] if i < len(page_graphs_1) else nx.DiGraph()
-        page_2 = page_graphs_2[i] if i < len(page_graphs_2) else nx.DiGraph()
-        edit += graph_edit_distance(page_1, page_2)
-    return edit
-
-
-def graph_edit_distance(g1, g2):
-    """
-    Compute graph edit distance between two graphs
-    """
-    for v in nx.optimize_graph_edit_distance(g1, g2,
-                                             node_ins_cost=lambda node: 1,
-                                             node_del_cost=lambda node: 1,
-                                             node_subst_cost=lambda node1, node2: 0 if node1["type"] == node2["type"] else 1,
-                                             edge_ins_cost=lambda edge: 1,
-                                             edge_del_cost=lambda edge: 1,
-                                             edge_subst_cost=lambda edge1, edge2: 0 if edge1 == edge2 else 1
-                                             ):
-        new_edit = v
-    return new_edit
-
-
-def edit_and_num_items_for_ged_from_str_read(str_gt, str_pred):
-    """
-    Compute graph edit distance and num nodes/edges for normalized graph edit distance
-    For the READ 2016 dataset
-    """
-    g_gt = str_to_graph_read(str_gt)
-    g_pred = str_to_graph_read(str_pred)
-    return graph_edit_distance_by_page_read(g_gt, g_pred), g_gt.number_of_nodes() + g_gt.number_of_edges()
-
-
-def edit_and_num_items_for_ged_from_str_rimes(str_gt, str_pred):
-    """
-    Compute graph edit distance and num nodes/edges for normalized graph edit distance
-    For the RIMES dataset
-    """
-    g_gt = str_to_graph_rimes(str_gt)
-    g_pred = str_to_graph_rimes(str_pred)
-    return graph_edit_distance(g_gt, g_pred), g_gt.number_of_nodes() + g_gt.number_of_edges()
-
-
-def edit_and_num_items_for_ged_from_str_simara(str_gt, str_pred):
-    """
-    Compute graph edit distance and num nodes/edges for normalized graph edit distance
-    For the SIMARA dataset
-    """
-    g_gt = str_to_graph_simara(str_gt)
-    g_pred = str_to_graph_simara(str_pred)
-    return graph_edit_distance(g_gt, g_pred), g_gt.number_of_nodes() + g_gt.number_of_edges()
diff --git a/basic/scheduler.py b/basic/scheduler.py
deleted file mode 100644
index 6c875c1da2f57ab0306635cbc77309c2b83af116..0000000000000000000000000000000000000000
--- a/basic/scheduler.py
+++ /dev/null
@@ -1,51 +0,0 @@
-
-from torch.nn import Dropout, Dropout2d
-import numpy as np
-
-
-class DropoutScheduler:
-
-    def __init__(self, models, function, T=1e5):
-        """
-        T: number of gradient updates to converge
-        """
-
-        self.teta_list = list()
-        self.init_teta_list(models)
-        self.function = function
-        self.T = T
-        self.step_num = 0
-
-    def step(self, num=1):
-        self.step_num += num
-
-    def init_teta_list(self, models):
-        for model_name in models.keys():
diff --git a/basic/scheduler.py b/basic/scheduler.py
deleted file mode 100644
index 6c875c1da2f57ab0306635cbc77309c2b83af116..0000000000000000000000000000000000000000
--- a/basic/scheduler.py
+++ /dev/null
@@ -1,51 +0,0 @@
-
-from torch.nn import Dropout, Dropout2d
-import numpy as np
-
-
-class DropoutScheduler:
-
-    def __init__(self, models, function, T=1e5):
-        """
-        T: number of gradient updates to converge
-        """
-        self.teta_list = list()
-        self.init_teta_list(models)
-        self.function = function
-        self.T = T
-        self.step_num = 0
-
-    def step(self, num=1):
-        self.step_num += num
-
-    def init_teta_list(self, models):
-        for model_name in models.keys():
-            self.init_teta_list_module(models[model_name])
-
-    def init_teta_list_module(self, module):
-        for child in module.children():
-            if isinstance(child, Dropout) or isinstance(child, Dropout2d):
-                self.teta_list.append([child, child.p])
-            else:
-                self.init_teta_list_module(child)
-
-    def update_dropout_rate(self):
-        for (module, p) in self.teta_list:
-            module.p = self.function(p, self.step_num, self.T)
-
-
-def exponential_dropout_scheduler(dropout_rate, step, max_step):
-    return dropout_rate * (1 - np.exp(-10 * step / max_step))
-
-
-def exponential_scheduler(init_value, end_value, step, max_step):
-    step = min(step, max_step-1)
-    return init_value - (init_value - end_value) * (1 - np.exp(-10*step/max_step))
-
-
-def linear_scheduler(init_value, end_value, step, max_step):
-    return init_value + step * (end_value - init_value) / max_step
\ No newline at end of file
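A hedged usage sketch of the dropout scheduler (toy model, arbitrary step count): each dropout layer's rate starts at 0 and ramps toward its configured value as `step_num` approaches `T`, following `exponential_dropout_scheduler`:

```python
import torch.nn as nn

# {name: module} dict, mirroring the structure DropoutScheduler expects.
models = {"encoder": nn.Sequential(nn.Conv2d(3, 8, 3), nn.Dropout2d(0.5), nn.Dropout(0.4))}
scheduler = DropoutScheduler(models, exponential_dropout_scheduler, T=1e5)

for _ in range(1000):
    # ... forward/backward pass and optimizer step would go here ...
    scheduler.step(1)                # one scheduler step per gradient update
    scheduler.update_dropout_rate()  # dropout p grows from 0 toward 0.5 / 0.4
```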
diff --git a/basic/utils.py b/basic/utils.py
deleted file mode 100644
index ac50e57ab1136a2f2e1d209c39271406316d0479..0000000000000000000000000000000000000000
--- a/basic/utils.py
+++ /dev/null
@@ -1,179 +0,0 @@
-
-
-import numpy as np
-import torch
-from torch.distributions.uniform import Uniform
-import cv2
-
-
-def randint(low, high):
-    """
-    Call torch.randint to keep randomness consistent among dataloader workers
-    """
-    return int(torch.randint(low, high, (1, )))
-
-
-def rand():
-    """
-    Call torch.rand to keep randomness consistent among dataloader workers
-    """
-    return float(torch.rand((1, )))
-
-
-def rand_uniform(low, high):
-    """
-    Call torch Uniform to keep randomness consistent among dataloader workers
-    """
-    return float(Uniform(low, high).sample())
-
-
-def pad_sequences_1D(data, padding_value):
-    """
-    Pad 1D sequences with padding_value so they all reach the same length
-    """
-    x_lengths = [len(x) for x in data]
-    longest_x = max(x_lengths)
-    padded_data = np.ones((len(data), longest_x)).astype(np.int32) * padding_value
-    for i, x_len in enumerate(x_lengths):
-        padded_data[i, :x_len] = data[i][:x_len]
-    return padded_data
-
-
-def resize_max(img, max_width=None, max_height=None):
-    """
-    Resize img so it does not exceed max_width/max_height, preserving the aspect ratio
-    """
-    if max_width is not None and img.shape[1] > max_width:
-        ratio = max_width / img.shape[1]
-        new_h = int(np.floor(ratio * img.shape[0]))
-        new_w = int(np.floor(ratio * img.shape[1]))
-        img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
-    if max_height is not None and img.shape[0] > max_height:
-        ratio = max_height / img.shape[0]
-        new_h = int(np.floor(ratio * img.shape[0]))
-        new_w = int(np.floor(ratio * img.shape[1]))
-        img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
-    return img
-
-
-def pad_images(data, padding_value, padding_mode="br"):
-    """
-    data: list of numpy arrays of shape (H, W, C)
-    padding_mode: "br"/"tl"/"random" (bottom-right, top-left, random)
-    """
-    x_lengths = [x.shape[0] for x in data]
-    y_lengths = [x.shape[1] for x in data]
-    longest_x = max(x_lengths)
-    longest_y = max(y_lengths)
-    padded_data = np.ones((len(data), longest_x, longest_y, data[0].shape[2])) * padding_value
-    for i, xy_len in enumerate(zip(x_lengths, y_lengths)):
-        x_len, y_len = xy_len
-        if padding_mode == "br":
-            padded_data[i, :x_len, :y_len, ...] = data[i]
-        elif padding_mode == "tl":
-            padded_data[i, -x_len:, -y_len:, ...] = data[i]
-        elif padding_mode == "random":
-            xmax = longest_x - x_len
-            ymax = longest_y - y_len
-            xi = randint(0, xmax) if xmax >= 1 else 0
-            yi = randint(0, ymax) if ymax >= 1 else 0
-            padded_data[i, xi:xi+x_len, yi:yi+y_len, ...] = data[i]
-        else:
-            raise NotImplementedError("Undefined padding mode: {}".format(padding_mode))
-    return padded_data
-
-
-def pad_image(image, padding_value, new_height=None, new_width=None, pad_width=None, pad_height=None, padding_mode="br", return_position=False):
-    """
-    image: numpy array of shape (H, W, C)
-    padding_mode: "br"/"tl"/"random" (bottom-right, top-left, random)
-    """
-    if pad_width is not None and new_width is not None:
-        raise NotImplementedError("pad_width and new_width are not compatible")
-    if pad_height is not None and new_height is not None:
-        raise NotImplementedError("pad_height and new_height are not compatible")
-
-    h, w, c = image.shape
-    pad_width = pad_width if pad_width is not None else max(0, new_width - w) if new_width is not None else 0
-    pad_height = pad_height if pad_height is not None else max(0, new_height - h) if new_height is not None else 0
-
-    if not (pad_width == 0 and pad_height == 0):
-        padded_image = np.ones((h+pad_height, w+pad_width, c)) * padding_value
-        if padding_mode == "br":
-            hi, wi = 0, 0
-        elif padding_mode == "tl":
-            hi, wi = pad_height, pad_width
-        elif padding_mode == "random":
-            hi = randint(0, pad_height) if pad_height >= 1 else 0
-            wi = randint(0, pad_width) if pad_width >= 1 else 0
-        else:
-            raise NotImplementedError("Undefined padding mode: {}".format(padding_mode))
-        padded_image[hi:hi + h, wi:wi + w, ...] = image
-        output = padded_image
-    else:
-        hi, wi = 0, 0
-        output = image
-
-    if return_position:
-        return output, [[hi, hi+h], [wi, wi+w]]
-    return output
-
-
-def pad_image_width_right(img, new_width, padding_value):
-    """
-    Pad img to right side with padding value to reach new_width as width
-    """
-    h, w, c = img.shape
-    pad_width = max((new_width - w), 0)
-    pad_right = np.ones((h, pad_width, c), dtype=img.dtype) * padding_value
-    img = np.concatenate([img, pad_right], axis=1)
-    return img
-
-
-def pad_image_width_left(img, new_width, padding_value):
-    """
-    Pad img to left side with padding value to reach new_width as width
-    """
-    h, w, c = img.shape
-    pad_width = max((new_width - w), 0)
-    pad_left = np.ones((h, pad_width, c), dtype=img.dtype) * padding_value
-    img = np.concatenate([pad_left, img], axis=1)
-    return img
-
-
-def pad_image_width_random(img, new_width, padding_value, max_pad_left_ratio=1):
-    """
-    Randomly pad img to left and right sides with padding value to reach new_width as width
-    """
-    h, w, c = img.shape
-    pad_width = max((new_width - w), 0)
-    max_pad_left = int(max_pad_left_ratio*pad_width)
-    pad_left = randint(0, min(pad_width, max_pad_left)) if pad_width != 0 and max_pad_left > 0 else 0
-    pad_right = pad_width - pad_left
-    pad_left = np.ones((h, pad_left, c), dtype=img.dtype) * padding_value
-    pad_right = np.ones((h, pad_right, c), dtype=img.dtype) * padding_value
-    img = np.concatenate([pad_left, img, pad_right], axis=1)
-    return img
-
-
-def pad_image_height_random(img, new_height, padding_value, max_pad_top_ratio=1):
-    """
-    Randomly pad img top and bottom sides with padding value to reach new_height as height
-    """
-    h, w, c = img.shape
-    pad_height = max((new_height - h), 0)
-    max_pad_top = int(max_pad_top_ratio*pad_height)
-    pad_top = randint(0, min(pad_height, max_pad_top)) if pad_height != 0 and max_pad_top > 0 else 0
-    pad_bottom = pad_height - pad_top
-    pad_top = np.ones((pad_top, w, c), dtype=img.dtype) * padding_value
-    pad_bottom = np.ones((pad_bottom, w, c), dtype=img.dtype) * padding_value
-    img = np.concatenate([pad_top, img, pad_bottom], axis=0)
-    return img
-
-
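A hedged usage sketch of the padding helpers (shapes arbitrary): `pad_images` collates variable-sized HWC images to the largest height and width in the list, and `pad_image` pads a single image, optionally returning where the original landed:

```python
import numpy as np

imgs = [np.zeros((32, 100, 3), dtype=np.uint8),
        np.zeros((48, 80, 3), dtype=np.uint8)]
batch = pad_images(imgs, padding_value=255, padding_mode="br")
print(batch.shape)  # (2, 48, 100, 3)

padded, pos = pad_image(imgs[0], padding_value=255, new_height=64,
                        new_width=128, padding_mode="random", return_position=True)
print(padded.shape, pos)  # (64, 128, 3) and [[hi, hi+32], [wi, wi+100]]
```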
-def pad_image_height_bottom(img, new_height, padding_value):
-    """
-    Pad img to bottom side with padding value to reach new_height as height
-    """
-    h, w, c = img.shape
-    pad_height = max((new_height - h), 0)
-    pad_bottom = np.ones((pad_height, w, c), dtype=img.dtype) * padding_value
-    img = np.concatenate([img, pad_bottom], axis=0)
-    return img
diff --git a/prediction-requirements.txt b/prediction-requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..86bfaeb8bf1dcef9ad52f4c64507482f5d942024
--- /dev/null
+++ b/prediction-requirements.txt
@@ -0,0 +1,8 @@
+arkindex-client==1.0.11
+editdistance==0.6.0
+fontTools==4.29.1
+imageio==2.16.0
+networkx==2.6.3
+tensorboard==0.2.1
+torchvision==0.12.0
+tqdm==4.62.3
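These pins install the prediction-only dependencies (`pip install -r prediction-requirements.txt`). A hedged sanity-check sketch, comparing an arbitrarily chosen subset of the pins against whatever is installed in the active environment:

```python
from importlib.metadata import version, PackageNotFoundError

# Subset of the pins above; names/versions copied from the requirements file.
PINS = {"editdistance": "0.6.0", "imageio": "2.16.0",
        "networkx": "2.6.3", "torchvision": "0.12.0", "tqdm": "4.62.3"}

for name, expected in PINS.items():
    try:
        found = version(name)
        status = "ok" if found == expected else "mismatch ({})".format(found)
    except PackageNotFoundError:
        status = "missing"
    print("{}=={} -> {}".format(name, expected, status))
```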