diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 09a6111c0d86921e36d749493e3ea43749109fbd..3703391f03074a9efb507832eeba643a8e901b62 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,3 @@ -# Only run on our the DAN python module -files: '^dan' - repos: - repo: https://github.com/PyCQA/isort rev: 5.10.1 diff --git a/Datasets/dataset_formatters/extract_from_arkindex.py b/Datasets/dataset_formatters/extract_from_arkindex.py deleted file mode 100644 index 4bad36391d46e25a0f599362311ecef9c27f2da7..0000000000000000000000000000000000000000 --- a/Datasets/dataset_formatters/extract_from_arkindex.py +++ /dev/null @@ -1,113 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -# -# Example of minimal usage: -# python extract_from_arkindex.py -# --corpus "AN-Simara annotations E (2022-06-20)" -# --parents-types folder -# --parents-names FRAN_IR_032031_4538.pdf -# --output-dir ../ - -""" - The extraction module - ====================== -""" - -import logging -import os - -import cv2 -import imageio.v2 as iio -from arkindex import ArkindexClient, options_from_env -from tqdm import tqdm - -from extraction_utils import arkindex_utils as ark -from extraction_utils import utils - -logging.basicConfig( - level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" -) - - -IMAGES_DIR = "./images/" # Path to the images directory. -LABELS_DIR = "./labels/" # Path to the labels directory. - -# Layout string to token -SEM_MATCHING_TOKENS_STR = { - "INTITULE": "ⓘ", - "DATE": "â““", - "COTE_SERIE": "â“¢", - "ANALYSE_COMPL.": "â“’", - "PRECISIONS_SUR_COTE": "â“Ÿ", - "COTE_ARTICLE": "â“" -} - -# Layout begin-token to end-token -SEM_MATCHING_TOKENS = { - "ⓘ": "â’¾", - "â““": "â’¹", - "â“¢": "Ⓢ", - "â“’": "â’¸", - "â“Ÿ": "â“…", - "â“": "â’¶" -} - - -if __name__ == '__main__': - args = utils.get_cli_args() - - # Get and initialize the parameters. - os.makedirs(IMAGES_DIR, exist_ok=True) - os.makedirs(LABELS_DIR, exist_ok=True) - - # Login to arkindex. - client = ArkindexClient(**options_from_env()) - - corpus = ark.retrieve_corpus(client, args.corpus) - subsets = ark.retrieve_subsets( - client, corpus, args.parents_types, args.parents_names - ) - - # Iterate over the subsets to find the page images and labels. - for subset in subsets: - - os.makedirs(os.path.join(args.output_dir, IMAGES_DIR, subset["name"]), exist_ok=True) - os.makedirs(os.path.join(args.output_dir, LABELS_DIR, subset["name"]), exist_ok=True) - - for page in tqdm( - client.paginate( - "ListElementChildren", id=subset["id"], type="page", recursive=True - ), - desc="Set " + subset["name"], - ): - - image = iio.imread(page["zone"]["url"]) - cv2.imwrite( - os.path.join(args.output_dir, IMAGES_DIR, subset['name'], f"{page['id']}.jpg"), - cv2.cvtColor(image, cv2.COLOR_BGR2RGB), - ) - - tr = client.request('ListTranscriptions', id=page['id'], worker_version=None)['results'] - tr = [one for one in tr if one['worker_version_id'] is None] - assert len(tr) == 1, page['id'] - - for one_tr in tr: - ent = client.request('ListTranscriptionEntities', id=one_tr['id'])['results'] - ent = [one for one in ent if one['worker_version_id'] is None] - if len(ent) == 0: - continue - else: - text = one_tr['text'] - - new_text = text - count = 0 - for e in ent: - start_token = SEM_MATCHING_TOKENS_STR[e['entity']['metas']['subtype']] - end_token = SEM_MATCHING_TOKENS[start_token] - new_text = new_text[:count+e['offset']] + start_token + new_text[count+e['offset']:] - count += 1 - new_text = new_text[:count+e['offset']+e['length']] + end_token + new_text[count+e['offset']+e['length']:] - count += 1 - - with open(os.path.join(args.output_dir, LABELS_DIR, subset['name'], f"{page['id']}.txt"), 'w') as f: - f.write(new_text) diff --git a/Datasets/dataset_formatters/extraction_utils/arkindex_utils.py b/Datasets/dataset_formatters/extraction_utils/arkindex_utils.py deleted file mode 100644 index 5216a7e0dd286cdfb53c73802324cd40cd07f7bd..0000000000000000000000000000000000000000 --- a/Datasets/dataset_formatters/extraction_utils/arkindex_utils.py +++ /dev/null @@ -1,67 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- - -""" - The arkindex_utils module - ====================== -""" - -import errno -import logging -import sys - -from apistar.exceptions import ErrorResponse - - -def retrieve_corpus(client, corpus_name: str) -> str: - """ - Retrieve the corpus id from the corpus name. - :param client: The arkindex client. - :param corpus_name: The name of the corpus to retrieve. - :return target_corpus: The id of the retrieved corpus. - """ - for corpus in client.request("ListCorpus"): - if corpus["name"] == corpus_name: - target_corpus = corpus["id"] - try: - logging.info(f"Corpus id retrieved: {target_corpus}") - except NameError: - logging.error(f"Corpus {corpus_name} not found") - sys.exit(errno.EINVAL) - - return target_corpus - - -def retrieve_subsets( - client, corpus: str, parents_types: list, parents_names: list -) -> list: - """ - Retrieve the requested subsets. - :param client: The arkindex client. - :param corpus: The id of the retrieved corpus. - :param parents_types: The types of parents of the elements to retrieve. - :param parents_names: The names of parents of the elements to retrieve. - :return subsets: The retrieved subsets. - """ - subsets = [] - for parent_type in parents_types: - try: - subsets.extend( - client.request("ListElements", corpus=corpus, type=parent_type)[ - "results" - ] - ) - except ErrorResponse as e: - logging.error(f"{e.content}: {parent_type}") - sys.exit(errno.EINVAL) - # Retrieve subsets with name in parents-names. If no parents-names given, keep all subsets. - if parents_names is not None: - logging.info(f"Retrieving {parents_names} subset(s)") - subsets = [subset for subset in subsets if subset["name"] in parents_names] - else: - logging.info("Retrieving all subsets") - - if len(subsets) == 0: - logging.info("No subset found") - - return subsets diff --git a/Datasets/dataset_formatters/extraction_utils/utils.py b/Datasets/dataset_formatters/extraction_utils/utils.py deleted file mode 100644 index fad5cdec84a0b68676804165a0606bbc63c8c7f9..0000000000000000000000000000000000000000 --- a/Datasets/dataset_formatters/extraction_utils/utils.py +++ /dev/null @@ -1,49 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- - -""" - The utils module - ====================== -""" - -import argparse - -def get_cli_args(): - """ - Get the command-line arguments. - :return: The command-line arguments. - """ - parser = argparse.ArgumentParser( - description="Arkindex DAN Training Label Generation" - ) - - # Required arguments. - parser.add_argument( - "--corpus", - type=str, - help="Name of the corpus from which the data will be retrieved.", - required=True, - ) - parser.add_argument( - "--parents-types", - nargs="+", - type=str, - help="Type of parents of the elements.", - required=True, - ) - parser.add_argument( - "--output-dir", - type=str, - help="Path to the output directory.", - required=True, - ) - - # Optional arguments. - parser.add_argument( - "--parents-names", - nargs="+", - type=str, - help="Names of parents of the elements.", - default=None, - ) - return parser.parse_args() diff --git a/Datasets/dataset_formatters/generic_dataset_formatter.py b/Datasets/dataset_formatters/generic_dataset_formatter.py deleted file mode 100644 index ae55af026c841af4a9ab8ee08e3437bf9d16c034..0000000000000000000000000000000000000000 --- a/Datasets/dataset_formatters/generic_dataset_formatter.py +++ /dev/null @@ -1,89 +0,0 @@ -import os -import shutil -import tarfile -import pickle -import re -from PIL import Image -import numpy as np - - -class DatasetFormatter: - """ - Global pipeline/functions for dataset formatting - """ - - def __init__(self, dataset_name, level, extra_name="", set_names=["train", "valid", "test"]): - self.dataset_name = dataset_name - self.level = level - self.set_names = set_names - self.target_fold_path = os.path.join( - "Datasets", "formatted", "{}_{}{}".format(dataset_name, level, extra_name)) - self.map_datasets_files = dict() - self.extract_with_dirname = False - - def format(self): - self.init_format() - self.map_datasets_files[self.dataset_name][self.level]["format_function"]() - self.end_format() - - def init_format(self): - """ - Load and extracts needed files - """ - os.makedirs(self.target_fold_path, exist_ok=True) - - for set_name in self.set_names: - os.makedirs(os.path.join(self.target_fold_path, set_name), exist_ok=True) - - -class OCRDatasetFormatter(DatasetFormatter): - """ - Specific pipeline/functions for OCR/HTR dataset formatting - """ - - def __init__(self, source_dataset, level, extra_name="", set_names=["train", "valid", "test"]): - super(OCRDatasetFormatter, self).__init__(source_dataset, level, extra_name, set_names) - self.charset = set() - self.gt = dict() - for set_name in set_names: - self.gt[set_name] = dict() - - def format_text_label(self, label): - """ - Remove extra space or line break characters - """ - temp = re.sub("(\n)+", '\n', label) - return re.sub("( )+", ' ', temp).strip(" \n") - - def load_resize_save(self, source_path, target_path): - """ - Load image, apply resolution modification and save it - """ - shutil.copyfile(source_path, target_path) - - def resize(self, img, source_dpi, target_dpi): - """ - Apply resolution modification to image - """ - if source_dpi == target_dpi: - return img - if isinstance(img, np.ndarray): - h, w = img.shape[:2] - img = Image.fromarray(img) - else: - w, h = img.size - ratio = target_dpi / source_dpi - img = img.resize((int(w*ratio), int(h*ratio)), Image.BILINEAR) - return np.array(img) - - def end_format(self): - """ - Save label and charset files - """ - with open(os.path.join(self.target_fold_path, "labels.pkl"), "wb") as f: - pickle.dump({ - "ground_truth": self.gt, - "charset": sorted(list(self.charset)), - }, f) - with open(os.path.join(self.target_fold_path, "charset.pkl"), "wb") as f: - pickle.dump(sorted(list(self.charset)), f) diff --git a/Datasets/dataset_formatters/simara_formatter.py b/Datasets/dataset_formatters/simara_formatter.py deleted file mode 100644 index c60eb30f94cf569d3b2d9d9197a16756ee994c01..0000000000000000000000000000000000000000 --- a/Datasets/dataset_formatters/simara_formatter.py +++ /dev/null @@ -1,104 +0,0 @@ -from Datasets.dataset_formatters.generic_dataset_formatter import OCRDatasetFormatter -import os -import numpy as np -from Datasets.dataset_formatters.utils_dataset import natural_sort -from PIL import Image -import xml.etree.ElementTree as ET -import re -from tqdm import tqdm - -# Layout string to token -SEM_MATCHING_TOKENS_STR = { - "INTITULE": "ⓘ", - "DATE": "â““", - "COTE_SERIE": "â“¢", - "ANALYSE_COMPL": "â“’", - "PRECISIONS_SUR_COTE": "â“Ÿ", - "COTE_ARTICLE": "â“" -} - -# Layout begin-token to end-token -SEM_MATCHING_TOKENS = { - "ⓘ": "â’¾", - "â““": "â’¹", - "â“¢": "Ⓢ", - "â“’": "â’¸", - "â“Ÿ": "â“…", - "â“": "â’¶" -} - -class SimaraDatasetFormatter(OCRDatasetFormatter): - def __init__(self, level, set_names=["train", "valid", "test"], dpi=150, sem_token=True): - super(SimaraDatasetFormatter, self).__init__("simara", level, "_sem" if sem_token else "", set_names) - - self.dpi = dpi - self.sem_token = sem_token - self.map_datasets_files.update({ - "simara": { - # (1,050 for train, 100 for validation and 100 for test) - "page": { - "format_function": self.format_simara_page, - }, - } - }) - self.matching_tokens_str = SEM_MATCHING_TOKENS_STR - self.matching_tokens = SEM_MATCHING_TOKENS - - def preformat_simara_page(self): - """ - Extract all information from dataset and correct some annotations - """ - dataset = { - "train": list(), - "valid": list(), - "test": list() - } - img_folder_path = os.path.join("Datasets", "raw", "simara", "images") - labels_folder_path = os.path.join("Datasets", "raw", "simara", "labels") - sem_labels_folder_path = os.path.join("Datasets", "raw", "simara", "labels_sem") - train_files = [ - os.path.join(labels_folder_path, 'train', name) - for name in os.listdir(os.path.join(sem_labels_folder_path, 'train'))] - valid_files = [ - os.path.join(labels_folder_path, 'valid', name) - for name in os.listdir(os.path.join(sem_labels_folder_path, 'valid'))] - test_files = [ - os.path.join(labels_folder_path, 'test', name) - for name in os.listdir(os.path.join(sem_labels_folder_path, 'test'))] - for set_name, files in zip(self.set_names, [train_files, valid_files, test_files]): - for i, label_file in enumerate(tqdm(files, desc='Pre-formatting '+set_name)): - with open(label_file, 'r') as f: - text = f.read() - with open(label_file.replace('labels', 'labels_sem'), 'r') as f: - sem_text = f.read() - dataset[set_name].append({ - "img_path": os.path.join( - img_folder_path, set_name, label_file.split('/')[-1].replace('txt', 'jpg')), - "label": text, - "sem_label": sem_text, - }) - print(dataset['test'], len(dataset['test'])) - return dataset - - def format_simara_page(self): - """ - Format simara page dataset - """ - dataset = self.preformat_simara_page() - for set_name in self.set_names: - fold = os.path.join(self.target_fold_path, set_name) - for sample in tqdm(dataset[set_name], desc='Formatting '+set_name): - new_name = sample['img_path'].split('/')[-1] - new_img_path = os.path.join(fold, new_name) - self.load_resize_save(sample["img_path"], new_img_path)#, 300, self.dpi) - page = { - "text": sample["label"] if not self.sem_token else sample["sem_label"], - } - self.charset = self.charset.union(set(page["text"])) - self.gt[set_name][new_name] = page - - -if __name__ == "__main__": - - SimaraDatasetFormatter("page", sem_token=True).format() - #SimaraDatasetFormatter("page", sem_token=False).format() diff --git a/Datasets/dataset_formatters/utils_dataset.py b/Datasets/dataset_formatters/utils_dataset.py deleted file mode 100644 index 962274da3a806f729904d839481f294e1d5d6ca2..0000000000000000000000000000000000000000 --- a/Datasets/dataset_formatters/utils_dataset.py +++ /dev/null @@ -1,7 +0,0 @@ -import re - - -def natural_sort(l): - convert = lambda text: int(text) if text.isdigit() else text.lower() - alphanum_key = lambda key: [ convert(c) for c in re.split('([0-9]+)', key)] - return sorted(l, key=alphanum_key) \ No newline at end of file diff --git a/OCR/document_OCR/dan/main_dan.py b/OCR/document_OCR/dan/main_dan.py deleted file mode 100644 index 2a15f0aa8ee4ea4f4b9bd9297a74180827ce71a5..0000000000000000000000000000000000000000 --- a/OCR/document_OCR/dan/main_dan.py +++ /dev/null @@ -1,207 +0,0 @@ -import os -import sys -DOSSIER_COURRANT = os.path.dirname(os.path.abspath(__file__)) -DOSSIER_PARENT = os.path.dirname(DOSSIER_COURRANT) -sys.path.append(os.path.dirname(DOSSIER_PARENT)) -sys.path.append(os.path.dirname(os.path.dirname(DOSSIER_PARENT))) -sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(DOSSIER_PARENT)))) -from torch.optim import Adam -from basic.transforms import aug_config -from OCR.ocr_dataset_manager import OCRDataset, OCRDatasetManager -from OCR.document_OCR.dan.trainer_dan import Manager -from dan.decoder import GlobalHTADecoder -from dan.models import FCN_Encoder -from basic.scheduler import exponential_dropout_scheduler, linear_scheduler -import torch -import numpy as np -import random -import torch.multiprocessing as mp - - -def train_and_test(rank, params): - torch.manual_seed(0) - torch.cuda.manual_seed(0) - np.random.seed(0) - random.seed(0) - torch.backends.cudnn.benchmark = False - torch.backends.cudnn.deterministic = True - - params["training_params"]["ddp_rank"] = rank - model = Manager(params) - model.load_model() - - model.train() - - # load weights giving best CER on valid set - model.params["training_params"]["load_epoch"] = "best" - model.load_model() - - metrics = ["cer", "wer", "time", "map_cer", "loer"] - for dataset_name in params["dataset_params"]["datasets"].keys(): - for set_name in ["test", "valid", "train"]: - model.predict("{}-{}".format(dataset_name, set_name), [(dataset_name, set_name), ], metrics, output=True) - - -if __name__ == "__main__": - - dataset_name = "simara" # ["RIMES", "READ_2016"] - dataset_level = "page" # ["page", "double_page"] - dataset_variant = "_sem" - - # max number of lines for synthetic documents - max_nb_lines = { - "RIMES": 40, - "READ_2016": 30, - } - - params = { - "dataset_params": { - "dataset_manager": OCRDatasetManager, - "dataset_class": OCRDataset, - "datasets": { - dataset_name: "../../../Datasets/formatted/{}_{}{}".format(dataset_name, dataset_level, dataset_variant), - }, - "train": { - "name": "{}-train".format(dataset_name), - "datasets": [(dataset_name, "train"), ], - }, - "valid": { - "{}-valid".format(dataset_name): [(dataset_name, "valid"), ], - }, - "config": { - "load_in_memory": True, # Load all images in CPU memory - "worker_per_gpu": 4, # Num of parallel processes per gpu for data loading - "width_divisor": 8, # Image width will be divided by 8 - "height_divisor": 32, # Image height will be divided by 32 - "padding_value": 0, # Image padding value - "padding_token": None, # Label padding value - "charset_mode": "seq2seq", # add end-of-transcription ans start-of-transcription tokens to charset - "constraints": ["add_eot", "add_sot"], # add end-of-transcription ans start-of-transcription tokens in labels - "normalize": True, # Normalize with mean and variance of training dataset - "preprocessings": [ - { - "type": "to_RGB", - # if grayscaled image, produce RGB one (3 channels with same value) otherwise do nothing - }, - ], - "augmentation": aug_config(0.9, 0.1), - "synthetic_data": None, - #"synthetic_data": { - # "init_proba": 0.9, # begin proba to generate synthetic document - # "end_proba": 0.2, # end proba to generate synthetic document - # "num_steps_proba": 200000, # linearly decrease the percent of synthetic document from 90% to 20% through 200000 samples - # "proba_scheduler_function": linear_scheduler, # decrease proba rate linearly - # "start_scheduler_at_max_line": True, # start decreasing proba only after curriculum reach max number of lines - # "dataset_level": dataset_level, - # "curriculum": True, # use curriculum learning (slowly increase number of lines per synthetic samples) - # "crop_curriculum": True, # during curriculum learning, crop images under the last text line - # "curr_start": 0, # start curriculum at iteration - # "curr_step": 10000, # interval to increase the number of lines for curriculum learning - # "min_nb_lines": 1, # initial number of lines for curriculum learning - # "max_nb_lines": max_nb_lines[dataset_name], # maximum number of lines for curriculum learning - # "padding_value": 255, - # # config for synthetic line generation - # "config": { - # "background_color_default": (255, 255, 255), - # "background_color_eps": 15, - # "text_color_default": (0, 0, 0), - # "text_color_eps": 15, - # "font_size_min": 35, - # "font_size_max": 45, - # "color_mode": "RGB", - # "padding_left_ratio_min": 0.00, - # "padding_left_ratio_max": 0.05, - # "padding_right_ratio_min": 0.02, - # "padding_right_ratio_max": 0.2, - # "padding_top_ratio_min": 0.02, - # "padding_top_ratio_max": 0.1, - # "padding_bottom_ratio_min": 0.02, - # "padding_bottom_ratio_max": 0.1, - # }, - #} - } - }, - - "model_params": { - "models": { - "encoder": FCN_Encoder, - "decoder": GlobalHTADecoder, - }, - #"transfer_learning": None, - "transfer_learning": { - # model_name: [state_dict_name, checkpoint_path, learnable, strict] - "encoder": ["encoder", "dan_rimes_page.pt", True, True], - "decoder": ["decoder", "dan_rimes_page.pt", True, False], - }, - "transfered_charset": True, # Transfer learning of the decision layer based on charset of the line HTR model - "additional_tokens": 1, # for decision layer = [<eot>, ], only for transfered charset - - "input_channels": 3, # number of channels of input image - "dropout": 0.5, # dropout rate for encoder - "enc_dim": 256, # dimension of extracted features - "nb_layers": 5, # encoder - "h_max": 500, # maximum height for encoder output (for 2D positional embedding) - "w_max": 1000, # maximum width for encoder output (for 2D positional embedding) - "l_max": 15000, # max predicted sequence (for 1D positional embedding) - "dec_num_layers": 8, # number of transformer decoder layers - "dec_num_heads": 4, # number of heads in transformer decoder layers - "dec_res_dropout": 0.1, # dropout in transformer decoder layers - "dec_pred_dropout": 0.1, # dropout rate before decision layer - "dec_att_dropout": 0.1, # dropout rate in multi head attention - "dec_dim_feedforward": 256, # number of dimension for feedforward layer in transformer decoder layers - "use_2d_pe": True, # use 2D positional embedding - "use_1d_pe": True, # use 1D positional embedding - "use_lstm": False, - "attention_win": 100, # length of attention window - # Curriculum dropout - "dropout_scheduler": { - "function": exponential_dropout_scheduler, - "T": 5e4, - } - - }, - - "training_params": { - "output_folder": "dan_simara_page", # folder name for checkpoint and results - "max_nb_epochs": 50000, # maximum number of epochs before to stop - "max_training_time": 3600 * 24 * 1.9, # maximum time before to stop (in seconds) - "load_epoch": "last", # ["best", "last"]: last to continue training, best to evaluate - "interval_save_weights": None, # None: keep best and last only - "batch_size": 2, # mini-batch size for training - "valid_batch_size": 4, # mini-batch size for valdiation - "use_ddp": False, # Use DistributedDataParallel - "ddp_port": "20027", - "use_amp": True, # Enable automatic mix-precision - "nb_gpu": torch.cuda.device_count(), - "optimizers": { - "all": { - "class": Adam, - "args": { - "lr": 0.0001, - "amsgrad": False, - } - }, - }, - "lr_schedulers": None, # Learning rate schedulers - "eval_on_valid": True, # Whether to eval and logs metrics on validation set during training or not - "eval_on_valid_interval": 5, # Interval (in epochs) to evaluate during training - "focus_metric": "cer", # Metrics to focus on to determine best epoch - "expected_metric_value": "low", # ["high", "low"] What is best for the focus metric value - "set_name_focus_metric": "{}-valid".format(dataset_name), # Which dataset to focus on to select best weights - "train_metrics": ["loss_ce", "cer", "wer", "syn_max_lines"], # Metrics name for training - "eval_metrics": ["cer", "wer", "map_cer"], # Metrics name for evaluation on validation set during training - "force_cpu": False, # True for debug purposes - "max_char_prediction": 1000, # max number of token prediction - # Keep teacher forcing rate to 20% during whole training - "teacher_forcing_scheduler": { - "min_error_rate": 0.2, - "max_error_rate": 0.2, - "total_num_steps": 5e4 - }, - }, - } - - if params["training_params"]["use_ddp"] and not params["training_params"]["force_cpu"]: - mp.spawn(train_and_test, args=(params,), nprocs=params["training_params"]["nb_gpu"]) - else: - train_and_test(0, params) diff --git a/OCR/document_OCR/dan/trainer_dan.py b/OCR/document_OCR/dan/trainer_dan.py deleted file mode 100644 index d8fd84324e88ff44e2983abfdcf53ac70709c08b..0000000000000000000000000000000000000000 --- a/OCR/document_OCR/dan/trainer_dan.py +++ /dev/null @@ -1,165 +0,0 @@ -from OCR.ocr_manager import OCRManager -from torch.nn import CrossEntropyLoss -import torch -from dan.ocr_utils import LM_ind_to_str -import numpy as np -from torch.cuda.amp import autocast -import time - - -class Manager(OCRManager): - - def __init__(self, params): - super(Manager, self).__init__(params) - - def load_save_info(self, info_dict): - if "curriculum_config" in info_dict.keys(): - if self.dataset.train_dataset is not None: - self.dataset.train_dataset.curriculum_config = info_dict["curriculum_config"] - - def add_save_info(self, info_dict): - info_dict["curriculum_config"] = self.dataset.train_dataset.curriculum_config - return info_dict - - def get_init_hidden(self, batch_size): - num_layers = 1 - hidden_size = self.params["model_params"]["enc_dim"] - return torch.zeros(num_layers, batch_size, hidden_size), torch.zeros(num_layers, batch_size, hidden_size) - - def apply_teacher_forcing(self, y, y_len, error_rate): - y_error = y.clone() - for b in range(len(y_len)): - for i in range(1, y_len[b]): - if np.random.rand() < error_rate and y[b][i] != self.dataset.tokens["pad"]: - y_error[b][i] = np.random.randint(0, len(self.dataset.charset)+2) - return y_error, y_len - - def train_batch(self, batch_data, metric_names): - loss_func = CrossEntropyLoss(ignore_index=self.dataset.tokens["pad"]) - - sum_loss = 0 - x = batch_data["imgs"].to(self.device) - y = batch_data["labels"].to(self.device) - reduced_size = [s[:2] for s in batch_data["imgs_reduced_shape"]] - y_len = batch_data["labels_len"] - - # add errors in teacher forcing - if "teacher_forcing_error_rate" in self.params["training_params"] and self.params["training_params"]["teacher_forcing_error_rate"] is not None: - error_rate = self.params["training_params"]["teacher_forcing_error_rate"] - simulated_y_pred, y_len = self.apply_teacher_forcing(y, y_len, error_rate) - elif "teacher_forcing_scheduler" in self.params["training_params"]: - error_rate = self.params["training_params"]["teacher_forcing_scheduler"]["min_error_rate"] + min(self.latest_step, self.params["training_params"]["teacher_forcing_scheduler"]["total_num_steps"]) * (self.params["training_params"]["teacher_forcing_scheduler"]["max_error_rate"]-self.params["training_params"]["teacher_forcing_scheduler"]["min_error_rate"]) / self.params["training_params"]["teacher_forcing_scheduler"]["total_num_steps"] - simulated_y_pred, y_len = self.apply_teacher_forcing(y, y_len, error_rate) - else: - simulated_y_pred = y - - with autocast(enabled=self.params["training_params"]["use_amp"]): - hidden_predict = None - cache = None - - raw_features = self.models["encoder"](x) - features_size = raw_features.size() - b, c, h, w = features_size - - pos_features = self.models["decoder"].features_updater.get_pos_features(raw_features) - features = torch.flatten(pos_features, start_dim=2, end_dim=3).permute(2, 0, 1) - enhanced_features = pos_features - enhanced_features = torch.flatten(enhanced_features, start_dim=2, end_dim=3).permute(2, 0, 1) - output, pred, hidden_predict, cache, weights = self.models["decoder"](features, enhanced_features, - simulated_y_pred[:, :-1], - reduced_size, - [max(y_len) for _ in range(b)], - features_size, - start=0, - hidden_predict=hidden_predict, - cache=cache, - keep_all_weights=True) - - loss_ce = loss_func(pred, y[:, 1:]) - sum_loss += loss_ce - with autocast(enabled=False): - self.backward_loss(sum_loss) - self.step_optimizers() - self.zero_optimizers() - predicted_tokens = torch.argmax(pred, dim=1).detach().cpu().numpy() - predicted_tokens = [predicted_tokens[i, :y_len[i]] for i in range(b)] - str_x = [LM_ind_to_str(self.dataset.charset, t, oov_symbol="") for t in predicted_tokens] - - values = { - "nb_samples": b, - "str_y": batch_data["raw_labels"], - "str_x": str_x, - "loss": sum_loss.item(), - "loss_ce": loss_ce.item(), - "syn_max_lines": self.dataset.train_dataset.get_syn_max_lines() if self.params["dataset_params"]["config"]["synthetic_data"] else 0, - } - - return values - - def evaluate_batch(self, batch_data, metric_names): - x = batch_data["imgs"].to(self.device) - reduced_size = [s[:2] for s in batch_data["imgs_reduced_shape"]] - - max_chars = self.params["training_params"]["max_char_prediction"] - - start_time = time.time() - with autocast(enabled=self.params["training_params"]["use_amp"]): - b = x.size(0) - reached_end = torch.zeros((b, ), dtype=torch.bool, device=self.device) - prediction_len = torch.zeros((b, ), dtype=torch.int, device=self.device) - predicted_tokens = torch.ones((b, 1), dtype=torch.long, device=self.device) * self.dataset.tokens["start"] - predicted_tokens_len = torch.ones((b, ), dtype=torch.int, device=self.device) - - whole_output = list() - confidence_scores = list() - cache = None - hidden_predict = None - if b > 1: - features_list = list() - for i in range(b): - pos = batch_data["imgs_position"] - features_list.append(self.models["encoder"](x[i:i+1, :, pos[i][0][0]:pos[i][0][1], pos[i][1][0]:pos[i][1][1]])) - max_height = max([f.size(2) for f in features_list]) - max_width = max([f.size(3) for f in features_list]) - features = torch.zeros((b, features_list[0].size(1), max_height, max_width), device=self.device, dtype=features_list[0].dtype) - for i in range(b): - features[i, :, :features_list[i].size(2), :features_list[i].size(3)] = features_list[i] - else: - features = self.models["encoder"](x) - features_size = features.size() - coverage_vector = torch.zeros((features.size(0), 1, features.size(2), features.size(3)), device=self.device) - pos_features = self.models["decoder"].features_updater.get_pos_features(features) - features = torch.flatten(pos_features, start_dim=2, end_dim=3).permute(2, 0, 1) - enhanced_features = pos_features - enhanced_features = torch.flatten(enhanced_features, start_dim=2, end_dim=3).permute(2, 0, 1) - - for i in range(0, max_chars): - output, pred, hidden_predict, cache, weights = self.models["decoder"](features, enhanced_features, predicted_tokens, reduced_size, predicted_tokens_len, features_size, start=0, hidden_predict=hidden_predict, cache=cache, num_pred=1) - whole_output.append(output) - confidence_scores.append(torch.max(torch.softmax(pred[:, :], dim=1), dim=1).values) - coverage_vector = torch.clamp(coverage_vector + weights, 0, 1) - predicted_tokens = torch.cat([predicted_tokens, torch.argmax(pred[:, :, -1], dim=1, keepdim=True)], dim=1) - reached_end = torch.logical_or(reached_end, torch.eq(predicted_tokens[:, -1], self.dataset.tokens["end"])) - predicted_tokens_len += 1 - - prediction_len[reached_end == False] = i + 1 - if torch.all(reached_end): - break - - confidence_scores = torch.cat(confidence_scores, dim=1).cpu().detach().numpy() - predicted_tokens = predicted_tokens[:, 1:] - prediction_len[torch.eq(reached_end, False)] = max_chars - 1 - predicted_tokens = [predicted_tokens[i, :prediction_len[i]] for i in range(b)] - confidence_scores = [confidence_scores[i, :prediction_len[i]].tolist() for i in range(b)] - str_x = [LM_ind_to_str(self.dataset.charset, t, oov_symbol="") for t in predicted_tokens] - - process_time = time.time() - start_time - - values = { - "nb_samples": b, - "str_y": batch_data["raw_labels"], - "str_x": str_x, - "confidence_score": confidence_scores, - "time": process_time, - } - return values diff --git a/OCR/line_OCR/ctc/main_line_ctc.py b/OCR/line_OCR/ctc/main_line_ctc.py deleted file mode 100644 index bb8fa1d6c277b8978bb251eb0259e17603e9e28e..0000000000000000000000000000000000000000 --- a/OCR/line_OCR/ctc/main_line_ctc.py +++ /dev/null @@ -1,173 +0,0 @@ - -import os -import sys -from os.path import dirname -DOSSIER_COURRANT = dirname(os.path.abspath(__file__)) -ROOT_FOLDER = dirname(dirname(dirname(DOSSIER_COURRANT))) -sys.path.append(ROOT_FOLDER) -from OCR.line_OCR.ctc.trainer_line_ctc import TrainerLineCTC -from OCR.line_OCR.ctc.models_line_ctc import Decoder -from dan.models import FCN_Encoder -from torch.optim import Adam -from basic.transforms import line_aug_config -from basic.scheduler import exponential_dropout_scheduler, exponential_scheduler -from OCR.ocr_dataset_manager import OCRDataset, OCRDatasetManager -import torch.multiprocessing as mp -import torch -import numpy as np -import random - - -def train_and_test(rank, params): - torch.manual_seed(0) - torch.cuda.manual_seed(0) - np.random.seed(0) - random.seed(0) - torch.backends.cudnn.benchmark = False - torch.backends.cudnn.deterministic = True - - params["training_params"]["ddp_rank"] = rank - model = TrainerLineCTC(params) - model.load_model() - - # Model trains until max_time_training or max_nb_epochs is reached - model.train() - - # load weights giving best CER on valid set - model.params["training_params"]["load_epoch"] = "best" - model.load_model() - - # compute metrics on train, valid and test sets (in eval conditions) - metrics = ["cer", "wer", "time", ] - for dataset_name in params["dataset_params"]["datasets"].keys(): - for set_name in ["test", "valid", "train", ]: - model.predict("{}-{}".format(dataset_name, set_name), [(dataset_name, set_name), ], metrics, output=True) - - -def main(): - dataset_name = "READ_2016" # ["RIMES", "READ_2016"] - dataset_level = "syn_line" - params = { - "dataset_params": { - "dataset_manager": OCRDatasetManager, - "dataset_class": OCRDataset, - "datasets": { - dataset_name: "../../../Datasets/formatted/{}_{}".format(dataset_name, dataset_level), - }, - "train": { - "name": "{}-train".format(dataset_name), - "datasets": [(dataset_name, "train"), ], - }, - "valid": { - "{}-valid".format(dataset_name): [(dataset_name, "valid"), ], - }, - "config": { - "load_in_memory": True, # Load all images in CPU memory - "worker_per_gpu": 8, # Num of parallel processes per gpu for data loading - "width_divisor": 8, # Image width will be divided by 8 - "height_divisor": 32, # Image height will be divided by 32 - "padding_value": 0, # Image padding value - "padding_token": 1000, # Label padding value (None: default value is chosen) - "padding_mode": "br", # Padding at bottom and right - "charset_mode": "CTC", # add blank token - "constraints": ["CTC_line", ], # Padding for CTC requirements if necessary - "normalize": True, # Normalize with mean and variance of training dataset - "padding": { - "min_height": "max", # Pad to reach max height of training samples - "min_width": "max", # Pad to reach max width of training samples - "min_pad": None, - "max_pad": None, - "mode": "br", # Padding at bottom and right - "train_only": False, # Add padding at training time and evaluation time - }, - "preprocessings": [ - { - "type": "to_RGB", - # if grayscale image, produce RGB one (3 channels with same value) otherwise do nothing - }, - ], - # Augmentation techniques to use at training time - "augmentation": line_aug_config(0.9, 0.1), - # - "synthetic_data": { - "mode": "line_hw_to_printed", - "init_proba": 1, - "end_proba": 1, - "num_steps_proba": 1e5, - "proba_scheduler_function": exponential_scheduler, - "config": { - "background_color_default": (255, 255, 255), - "background_color_eps": 15, - "text_color_default": (0, 0, 0), - "text_color_eps": 15, - "font_size_min": 30, - "font_size_max": 50, - "color_mode": "RGB", - "padding_left_ratio_min": 0.02, - "padding_left_ratio_max": 0.1, - "padding_right_ratio_min": 0.02, - "padding_right_ratio_max": 0.1, - "padding_top_ratio_min": 0.02, - "padding_top_ratio_max": 0.2, - "padding_bottom_ratio_min": 0.02, - "padding_bottom_ratio_max": 0.2, - }, - }, - } - }, - - "model_params": { - # Model classes to use for each module - "models": { - "encoder": FCN_Encoder, - "decoder": Decoder, - }, - "transfer_learning": None, - "input_channels": 3, # 1 for grayscale images, 3 for RGB ones (or grayscale as RGB) - "enc_size": 256, - "dropout_scheduler": { - "function": exponential_dropout_scheduler, - "T": 5e4, - }, - "dropout": 0.5, - }, - - "training_params": { - "output_folder": "FCN_read_2016_line_syn", # folder names for logs and weigths - "max_nb_epochs": 10000, # max number of epochs for the training - "max_training_time": 3600 * 24 * 1.9, # max training time limit (in seconds) - "load_epoch": "last", # ["best", "last"], to load weights from best epoch or last trained epoch - "interval_save_weights": None, # None: keep best and last only - "use_ddp": False, # Use DistributedDataParallel - "use_amp": True, # Enable automatic mix-precision - "nb_gpu": torch.cuda.device_count(), - "batch_size": 16, # mini-batch size per GPU - "optimizers": { - "all": { - "class": Adam, - "args": { - "lr": 0.0001, - "amsgrad": False, - } - } - }, - "lr_schedulers": None, # Learning rate schedulers - "eval_on_valid": True, # Whether to eval and logs metrics on validation set during training or not - "eval_on_valid_interval": 2, # Interval (in epochs) to evaluate during training - "focus_metric": "cer", # Metrics to focus on to determine best epoch - "expected_metric_value": "low", # ["high", "low"] What is best for the focus metric value - "set_name_focus_metric": "{}-valid".format(dataset_name), # Which dataset to focus on to select best weights - "train_metrics": ["loss_ctc", "cer", "wer"], # Metrics name for training - "eval_metrics": ["loss_ctc", "cer", "wer"], # Metrics name for evaluation on validation set during training - "force_cpu": False, # True for debug purposes to run on cpu only - }, - } - - if params["training_params"]["use_ddp"] and not params["training_params"]["force_cpu"]: - mp.spawn(train_and_test, args=(params,), nprocs=params["training_params"]["nb_gpu"]) - else: - train_and_test(0, params) - - -if __name__ == "__main__": - main() diff --git a/OCR/line_OCR/ctc/main_syn_line.py b/OCR/line_OCR/ctc/main_syn_line.py deleted file mode 100644 index c16f9de5ab1fb38ce18486146f997ba60baedc61..0000000000000000000000000000000000000000 --- a/OCR/line_OCR/ctc/main_syn_line.py +++ /dev/null @@ -1,148 +0,0 @@ - -import os -import sys -from os.path import dirname -DOSSIER_COURRANT = dirname(os.path.abspath(__file__)) -ROOT_FOLDER = dirname(dirname(dirname(DOSSIER_COURRANT))) -sys.path.append(ROOT_FOLDER) -from OCR.line_OCR.ctc.trainer_line_ctc import TrainerLineCTC -from OCR.line_OCR.ctc.models_line_ctc import Decoder -from dan.models import FCN_Encoder -from torch.optim import Adam -from basic.transforms import line_aug_config -from basic.scheduler import exponential_dropout_scheduler, exponential_scheduler -from OCR.ocr_dataset_manager import OCRDataset, OCRDatasetManager -import torch.multiprocessing as mp -import torch -import numpy as np -import random - - -def train_and_test(rank, params): - torch.manual_seed(0) - torch.cuda.manual_seed(0) - np.random.seed(0) - random.seed(0) - torch.backends.cudnn.benchmark = False - torch.backends.cudnn.deterministic = True - - params["training_params"]["ddp_rank"] = rank - model = TrainerLineCTC(params) - - model.generate_syn_line_dataset("READ_2016_syn_line") # ["RIMES_syn_line", "READ_2016_syn_line"] - - -def main(): - dataset_name = "READ_2016" # ["RIMES", "READ_2016"] - dataset_level = "page" - params = { - "dataset_params": { - "dataset_manager": OCRDatasetManager, - "dataset_class": OCRDataset, - "datasets": { - dataset_name: "../../../Datasets/formatted/{}_{}".format(dataset_name, dataset_level), - }, - "train": { - "name": "{}-train".format(dataset_name), - "datasets": [(dataset_name, "train"), ], - }, - "valid": { - "{}-valid".format(dataset_name): [(dataset_name, "valid"), ], - }, - "config": { - "load_in_memory": False, # Load all images in CPU memory - "worker_per_gpu": 4, - "width_divisor": 8, # Image width will be divided by 8 - "height_divisor": 32, # Image height will be divided by 32 - "padding_value": 0, # Image padding value - "padding_token": 1000, # Label padding value (None: default value is chosen) - "padding_mode": "br", # Padding at bottom and right - "charset_mode": "CTC", # add blank token - "constraints": [], # Padding for CTC requirements if necessary - "normalize": True, # Normalize with mean and variance of training dataset - "preprocessings": [], - # Augmentation techniques to use at training time - "augmentation": line_aug_config(0.9, 0.1), - # - "synthetic_data": { - "mode": "line_hw_to_printed", - "init_proba": 1, - "end_proba": 1, - "num_steps_proba": 1e5, - "proba_scheduler_function": exponential_scheduler, - "config": { - "background_color_default": (255, 255, 255), - "background_color_eps": 15, - "text_color_default": (0, 0, 0), - "text_color_eps": 15, - "font_size_min": 30, - "font_size_max": 50, - "color_mode": "RGB", - "padding_left_ratio_min": 0.02, - "padding_left_ratio_max": 0.1, - "padding_right_ratio_min": 0.02, - "padding_right_ratio_max": 0.1, - "padding_top_ratio_min": 0.02, - "padding_top_ratio_max": 0.2, - "padding_bottom_ratio_min": 0.02, - "padding_bottom_ratio_max": 0.2, - }, - }, - } - }, - - "model_params": { - # Model classes to use for each module - "models": { - "encoder": FCN_Encoder, - "decoder": Decoder, - }, - "transfer_learning": None, - "input_channels": 3, # 1 for grayscale images, 3 for RGB ones (or grayscale as RGB) - "enc_size": 256, - "dropout_scheduler": { - "function": exponential_dropout_scheduler, - "T": 5e4, - }, - "dropout": 0.5, - }, - - "training_params": { - "output_folder": "FCN_Encoder_read_syn_line_all_pad_max_cursive", # folder names for logs and weigths - "max_nb_epochs": 10000, # max number of epochs for the training - "max_training_time": 3600 * 24 * 1.9, # max training time limit (in seconds) - "load_epoch": "last", # ["best", "last"], to load weights from best epoch or last trained epoch - "interval_save_weights": None, # None: keep best and last only - "use_ddp": False, # Use DistributedDataParallel - "use_amp": True, # Enable automatic mix-precision - "nb_gpu": torch.cuda.device_count(), - "batch_size": 1, # mini-batch size per GPU - "optimizers": { - "all": { - "class": Adam, - "args": { - "lr": 0.0001, - "amsgrad": False, - } - } - }, - "lr_schedulers": None, - "eval_on_valid": True, # Whether to eval and logs metrics on validation set during training or not - "eval_on_valid_interval": 2, # Interval (in epochs) to evaluate during training - "focus_metric": "cer", # Metrics to focus on to determine best epoch - "expected_metric_value": "low", # ["high", "low"] What is best for the focus metric value - "set_name_focus_metric": "{}-valid".format(dataset_name), - "train_metrics": ["loss_ctc", "cer", "wer"], # Metrics name for training - "eval_metrics": ["loss_ctc", "cer", "wer"], # Metrics name for evaluation on validation set during training - "force_cpu": False, # True for debug purposes to run on cpu only - }, - } - - if params["training_params"]["use_ddp"] and not params["training_params"]["force_cpu"]: - mp.spawn(train_and_test, args=(params,), nprocs=params["training_params"]["nb_gpu"]) - else: - train_and_test(0, params) - - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/OCR/line_OCR/ctc/models_line_ctc.py b/OCR/line_OCR/ctc/models_line_ctc.py deleted file mode 100644 index c13c072ebd10e810d41b8aab1e318c2170637ea0..0000000000000000000000000000000000000000 --- a/OCR/line_OCR/ctc/models_line_ctc.py +++ /dev/null @@ -1,19 +0,0 @@ - -from torch.nn.functional import log_softmax -from torch.nn import AdaptiveMaxPool2d, Conv1d -from torch.nn import Module - - -class Decoder(Module): - def __init__(self, params): - super(Decoder, self).__init__() - - self.vocab_size = params["vocab_size"] - - self.ada_pool = AdaptiveMaxPool2d((1, None)) - self.end_conv = Conv1d(in_channels=params["enc_size"], out_channels=self.vocab_size+1, kernel_size=1) - - def forward(self, x): - x = self.ada_pool(x).squeeze(2) - x = self.end_conv(x) - return log_softmax(x, dim=1) diff --git a/OCR/line_OCR/ctc/trainer_line_ctc.py b/OCR/line_OCR/ctc/trainer_line_ctc.py deleted file mode 100644 index eba12d81ecb4a417115121822d640be374cf185a..0000000000000000000000000000000000000000 --- a/OCR/line_OCR/ctc/trainer_line_ctc.py +++ /dev/null @@ -1,96 +0,0 @@ - -from basic.metric_manager import MetricManager -from OCR.ocr_manager import OCRManager -from dan.ocr_utils import LM_ind_to_str -import torch -from torch.cuda.amp import autocast -from torch.nn import CTCLoss -import re -import time - - -class TrainerLineCTC(OCRManager): - - def __init__(self, params): - super(TrainerLineCTC, self).__init__(params) - - def train_batch(self, batch_data, metric_names): - """ - Forward and backward pass for training - """ - x = batch_data["imgs"].to(self.device) - y = batch_data["labels"].to(self.device) - x_reduced_len = [s[1] for s in batch_data["imgs_reduced_shape"]] - y_len = batch_data["labels_len"] - - loss_ctc = CTCLoss(blank=self.dataset.tokens["blank"]) - self.zero_optimizers() - - with autocast(enabled=self.params["training_params"]["use_amp"]): - x = self.models["encoder"](x) - global_pred = self.models["decoder"](x) - - loss = loss_ctc(global_pred.permute(2, 0, 1), y, x_reduced_len, y_len) - - self.backward_loss(loss) - - self.step_optimizers() - pred = torch.argmax(global_pred, dim=1).cpu().numpy() - - values = { - "nb_samples": len(batch_data["raw_labels"]), - "loss_ctc": loss.item(), - "str_x": self.pred_to_str(pred, x_reduced_len), - "str_y": batch_data["raw_labels"] - } - - return values - - def evaluate_batch(self, batch_data, metric_names): - """ - Forward pass only for validation and test - """ - x = batch_data["imgs"].to(self.device) - y = batch_data["labels"].to(self.device) - x_reduced_len = [s[1] for s in batch_data["imgs_reduced_shape"]] - y_len = batch_data["labels_len"] - - loss_ctc = CTCLoss(blank=self.dataset.tokens["blank"]) - - start_time = time.time() - with autocast(enabled=self.params["training_params"]["use_amp"]): - x = self.models["encoder"](x) - global_pred = self.models["decoder"](x) - - loss = loss_ctc(global_pred.permute(2, 0, 1), y, x_reduced_len, y_len) - pred = torch.argmax(global_pred, dim=1).cpu().numpy() - str_x = self.pred_to_str(pred, x_reduced_len) - - process_time =time.time() - start_time - - values = { - "nb_samples": len(batch_data["raw_labels"]), - "loss_ctc": loss.item(), - "str_x": str_x, - "str_y": batch_data["raw_labels"], - "time": process_time - } - return values - - def ctc_remove_successives_identical_ind(self, ind): - res = [] - for i in ind: - if res and res[-1] == i: - continue - res.append(i) - return res - - def pred_to_str(self, pred, pred_len): - """ - convert prediction tokens to string - """ - ind_x = [pred[i][:pred_len[i]] for i in range(pred.shape[0])] - ind_x = [self.ctc_remove_successives_identical_ind(t) for t in ind_x] - str_x = [LM_ind_to_str(self.dataset.charset, t, oov_symbol="") for t in ind_x] - str_x = [re.sub("( )+", ' ', t).strip(" ") for t in str_x] - return str_x diff --git a/OCR/ocr_dataset_manager.py b/OCR/ocr_dataset_manager.py deleted file mode 100644 index ed7c087cba8a3b5ef7d435a8c73eba01d35ff40f..0000000000000000000000000000000000000000 --- a/OCR/ocr_dataset_manager.py +++ /dev/null @@ -1,1036 +0,0 @@ -# Copyright Université de Rouen Normandie (1), INSA Rouen (2), -# tutelles du laboratoire LITIS (1 et 2) -# contributors : -# - Denis Coquenet -# -# -# This software is a computer program written in Python whose purpose is to -# provide public implementation of deep learning works, in pytorch. -# -# This software is governed by the CeCILL-C license under French law and -# abiding by the rules of distribution of free software. You can use, -# modify and/ or redistribute the software under the terms of the CeCILL-C -# license as circulated by CEA, CNRS and INRIA at the following URL -# "http://www.cecill.info". -# -# As a counterpart to the access to the source code and rights to copy, -# modify and redistribute granted by the license, users are provided only -# with a limited warranty and the software's author, the holder of the -# economic rights, and the successive licensors have only limited -# liability. -# -# In this respect, the user's attention is drawn to the risks associated -# with loading, using, modifying and/or developing or reproducing the -# software by the user in light of its specific status of free software, -# that may mean that it is complicated to manipulate, and that also -# therefore means that it is reserved for developers and experienced -# professionals having in-depth computer knowledge. Users are therefore -# encouraged to load and test the software's suitability as regards their -# requirements in conditions enabling the security of their systems and/or -# data to be ensured and, more generally, to use and operate it in the -# same conditions as regards security. -# -# The fact that you are presently reading this means that you have had -# knowledge of the CeCILL-C license and that you accept its terms. -import numpy.random - -from basic.generic_dataset_manager import DatasetManager, GenericDataset -from basic.utils import pad_images, pad_image_width_right, resize_max, pad_image_width_random, pad_sequences_1D, pad_image_height_random, pad_image_width_left, pad_image -from basic.utils import randint, rand, rand_uniform -from basic.generic_dataset_manager import apply_preprocessing -from Datasets.dataset_formatters.read2016_formatter import SEM_MATCHING_TOKENS as READ_MATCHING_TOKENS -from Datasets.dataset_formatters.rimes_formatter import order_text_regions as order_text_regions_rimes -from Datasets.dataset_formatters.rimes_formatter import SEM_MATCHING_TOKENS as RIMES_MATCHING_TOKENS -from Datasets.dataset_formatters.rimes_formatter import SEM_MATCHING_TOKENS_STR as RIMES_MATCHING_TOKENS_STR -from dan.ocr_utils import LM_str_to_ind -import random -import cv2 -import os -import copy -import pickle -import numpy as np -import torch -import matplotlib -from PIL import Image, ImageDraw, ImageFont -from basic.transforms import RandomRotation, apply_transform, Tightening -from fontTools.ttLib import TTFont -from fontTools.unicode import Unicode - - -class OCRDatasetManager(DatasetManager): - """ - Specific class to handle OCR/HTR tasks - """ - - def __init__(self, params): - super(OCRDatasetManager, self).__init__(params) - - self.charset = params["charset"] if "charset" in params else self.get_merged_charsets() - - if "synthetic_data" in self.params["config"] and self.params["config"]["synthetic_data"] and "config" in self.params["config"]["synthetic_data"]: - self.char_only_set = self.charset.copy() - for token_dict in [RIMES_MATCHING_TOKENS, READ_MATCHING_TOKENS]: - for key in token_dict: - if key in self.char_only_set: - self.char_only_set.remove(key) - if token_dict[key] in self.char_only_set: - self.char_only_set.remove(token_dict[key]) - for token in ["\n", ]: - if token in self.char_only_set: - self.char_only_set.remove(token) - self.params["config"]["synthetic_data"]["config"]["valid_fonts"] = get_valid_fonts(self.char_only_set) - - if "new_tokens" in params: - self.charset = sorted(list(set(self.charset).union(set(params["new_tokens"])))) - - self.tokens = { - "pad": params["config"]["padding_token"], - } - if self.params["config"]["charset_mode"].lower() == "ctc": - self.tokens["blank"] = len(self.charset) - self.tokens["pad"] = self.tokens["pad"] if self.tokens["pad"] else len(self.charset) + 1 - self.params["config"]["padding_token"] = self.tokens["pad"] - elif self.params["config"]["charset_mode"] == "seq2seq": - self.tokens["end"] = len(self.charset) - self.tokens["start"] = len(self.charset) + 1 - self.tokens["pad"] = self.tokens["pad"] if self.tokens["pad"] else len(self.charset) + 2 - self.params["config"]["padding_token"] = self.tokens["pad"] - - def get_merged_charsets(self): - """ - Merge the charset of the different datasets used - """ - datasets = self.params["datasets"] - charset = set() - for key in datasets.keys(): - with open(os.path.join(datasets[key], "labels.pkl"), "rb") as f: - info = pickle.load(f) - charset = charset.union(set(info["charset"])) - if "\n" in charset and "remove_linebreaks" in self.params["config"]["constraints"]: - charset.remove("\n") - if "" in charset: - charset.remove("") - return sorted(list(charset)) - - def apply_specific_treatment_after_dataset_loading(self, dataset): - dataset.charset = self.charset - dataset.tokens = self.tokens - dataset.convert_labels() - if "READ_2016" in dataset.name and "augmentation" in dataset.params["config"] and dataset.params["config"]["augmentation"]: - dataset.params["config"]["augmentation"]["fill_value"] = tuple([int(i) for i in dataset.mean]) - if "padding" in dataset.params["config"] and dataset.params["config"]["padding"]["min_height"] == "max": - dataset.params["config"]["padding"]["min_height"] = max([s["img"].shape[0] for s in self.train_dataset.samples]) - if "padding" in dataset.params["config"] and dataset.params["config"]["padding"]["min_width"] == "max": - dataset.params["config"]["padding"]["min_width"] = max([s["img"].shape[1] for s in self.train_dataset.samples]) - - -class OCRDataset(GenericDataset): - """ - Specific class to handle OCR/HTR datasets - """ - - def __init__(self, params, set_name, custom_name, paths_and_sets): - super(OCRDataset, self).__init__(params, set_name, custom_name, paths_and_sets) - self.charset = None - self.tokens = None - self.reduce_dims_factor = np.array([params["config"]["height_divisor"], params["config"]["width_divisor"], 1]) - self.collate_function = OCRCollateFunction - self.synthetic_id = 0 - - def __getitem__(self, idx): - sample = copy.deepcopy(self.samples[idx]) - - if not self.load_in_memory: - sample["img"] = self.get_sample_img(idx) - sample = apply_preprocessing(sample, self.params["config"]["preprocessings"]) - - if "synthetic_data" in self.params["config"] and self.params["config"]["synthetic_data"] and self.set_name == "train": - sample = self.generate_synthetic_data(sample) - - # Data augmentation - sample["img"], sample["applied_da"] = self.apply_data_augmentation(sample["img"]) - - if "max_size" in self.params["config"] and self.params["config"]["max_size"]: - max_ratio = max(sample["img"].shape[0] / self.params["config"]["max_size"]["max_height"], sample["img"].shape[1] / self.params["config"]["max_size"]["max_width"]) - if max_ratio > 1: - new_h, new_w = int(np.ceil(sample["img"].shape[0] / max_ratio)), int(np.ceil(sample["img"].shape[1] / max_ratio)) - sample["img"] = cv2.resize(sample["img"], (new_w, new_h)) - - # Normalization if requested - if "normalize" in self.params["config"] and self.params["config"]["normalize"]: - sample["img"] = (sample["img"] - self.mean) / self.std - - sample["img_shape"] = sample["img"].shape - sample["img_reduced_shape"] = np.ceil(sample["img_shape"] / self.reduce_dims_factor).astype(int) - - # Padding to handle CTC requirements - if self.set_name == "train": - max_label_len = 0 - height = 1 - ctc_padding = False - if "CTC_line" in self.params["config"]["constraints"]: - max_label_len = sample["label_len"] - ctc_padding = True - if "CTC_va" in self.params["config"]["constraints"]: - max_label_len = max(sample["line_label_len"]) - ctc_padding = True - if "CTC_pg" in self.params["config"]["constraints"]: - max_label_len = sample["label_len"] - height = max(sample["img_reduced_shape"][0], 1) - ctc_padding = True - if ctc_padding and 2 * max_label_len + 1 > sample["img_reduced_shape"][1]*height: - sample["img"] = pad_image_width_right(sample["img"], int(np.ceil((2 * max_label_len + 1) / height) * self.reduce_dims_factor[1]), self.padding_value) - sample["img_shape"] = sample["img"].shape - sample["img_reduced_shape"] = np.ceil(sample["img_shape"] / self.reduce_dims_factor).astype(int) - sample["img_reduced_shape"] = [max(1, t) for t in sample["img_reduced_shape"]] - - sample["img_position"] = [[0, sample["img_shape"][0]], [0, sample["img_shape"][1]]] - # Padding constraints to handle model needs - if "padding" in self.params["config"] and self.params["config"]["padding"]: - if self.set_name == "train" or not self.params["config"]["padding"]["train_only"]: - min_pad = self.params["config"]["padding"]["min_pad"] - max_pad = self.params["config"]["padding"]["max_pad"] - pad_width = randint(min_pad, max_pad) if min_pad is not None and max_pad is not None else None - pad_height = randint(min_pad, max_pad) if min_pad is not None and max_pad is not None else None - - sample["img"], sample["img_position"] = pad_image(sample["img"], padding_value=self.padding_value, - new_width=self.params["config"]["padding"]["min_width"], - new_height=self.params["config"]["padding"]["min_height"], - pad_width=pad_width, - pad_height=pad_height, - padding_mode=self.params["config"]["padding"]["mode"], - return_position=True) - sample["img_reduced_position"] = [np.ceil(p / factor).astype(int) for p, factor in zip(sample["img_position"], self.reduce_dims_factor[:2])] - return sample - - - def get_charset(self): - charset = set() - for i in range(len(self.samples)): - charset = charset.union(set(self.samples[i]["label"])) - return charset - - def convert_labels(self): - """ - Label str to token at character level - """ - for i in range(len(self.samples)): - self.samples[i] = self.convert_sample_labels(self.samples[i]) - - def convert_sample_labels(self, sample): - label = sample["label"] - line_labels = label.split("\n") - if "remove_linebreaks" in self.params["config"]["constraints"]: - full_label = label.replace("\n", " ").replace(" ", " ") - word_labels = full_label.split(" ") - else: - full_label = label - word_labels = label.replace("\n", " ").replace(" ", " ").split(" ") - - sample["label"] = full_label - sample["token_label"] = LM_str_to_ind(self.charset, full_label) - if "add_eot" in self.params["config"]["constraints"]: - sample["token_label"].append(self.tokens["end"]) - sample["label_len"] = len(sample["token_label"]) - if "add_sot" in self.params["config"]["constraints"]: - sample["token_label"].insert(0, self.tokens["start"]) - - sample["line_label"] = line_labels - sample["token_line_label"] = [LM_str_to_ind(self.charset, l) for l in line_labels] - sample["line_label_len"] = [len(l) for l in line_labels] - sample["nb_lines"] = len(line_labels) - - sample["word_label"] = word_labels - sample["token_word_label"] = [LM_str_to_ind(self.charset, l) for l in word_labels] - sample["word_label_len"] = [len(l) for l in word_labels] - sample["nb_words"] = len(word_labels) - return sample - - def generate_synthetic_data(self, sample): - config = self.params["config"]["synthetic_data"] - - if not (config["init_proba"] == config["end_proba"] == 1): - nb_samples = self.training_info["step"] * self.params["batch_size"] - if config["start_scheduler_at_max_line"]: - max_step = config["num_steps_proba"] - current_step = max(0, min(nb_samples-config["curr_step"]*(config["max_nb_lines"]-config["min_nb_lines"]), max_step)) - proba = config["init_proba"] if self.get_syn_max_lines() < config["max_nb_lines"] else \ - config["proba_scheduler_function"](config["init_proba"], config["end_proba"], current_step, max_step) - else: - proba = config["proba_scheduler_function"](config["init_proba"], config["end_proba"], - min(nb_samples, config["num_steps_proba"]), - config["num_steps_proba"]) - if rand() > proba: - return sample - - if "mode" in config and config["mode"] == "line_hw_to_printed": - sample["img"] = self.generate_typed_text_line_image(sample["label"]) - return sample - - return self.generate_synthetic_page_sample() - - def get_syn_max_lines(self): - config = self.params["config"]["synthetic_data"] - if config["curriculum"]: - nb_samples = self.training_info["step"]*self.params["batch_size"] - max_nb_lines = min(config["max_nb_lines"], (nb_samples-config["curr_start"]) // config["curr_step"]+1) - return max(config["min_nb_lines"], max_nb_lines) - return config["max_nb_lines"] - - def generate_synthetic_page_sample(self): - config = self.params["config"]["synthetic_data"] - max_nb_lines_per_page = self.get_syn_max_lines() - crop = config["crop_curriculum"] and max_nb_lines_per_page < config["max_nb_lines"] - sample = { - "name": "synthetic_data_{}".format(self.synthetic_id), - "path": None - } - self.synthetic_id += 1 - nb_pages = 2 if "double" in config["dataset_level"] else 1 - background_sample = copy.deepcopy(self.samples[randint(0, len(self))]) - pages = list() - backgrounds = list() - - h, w, c = background_sample["img"].shape - page_width = w // 2 if nb_pages == 2 else w - for i in range(nb_pages): - nb_lines_per_page = randint(config["min_nb_lines"], max_nb_lines_per_page+1) - background = np.ones((h, page_width, c), dtype=background_sample["img"].dtype) * 255 - if i == 0 and nb_pages == 2: - background[:, -2:, :] = 0 - backgrounds.append(background) - if "READ_2016" in self.params["datasets"].keys(): - side = background_sample["pages_label"][i]["side"] - coords = { - "left": int(0.15 * page_width) if side == "left" else int(0.05 * page_width), - "right": int(0.95 * page_width) if side == "left" else int(0.85 * page_width), - "top": int(0.05 * h), - "bottom": int(0.85 * h), - } - pages.append(self.generate_synthetic_read2016_page(background, coords, side=side, crop=crop, - nb_lines=nb_lines_per_page)) - elif "RIMES" in self.params["datasets"].keys(): - pages.append(self.generate_synthetic_rimes_page(background, nb_lines=nb_lines_per_page, crop=crop)) - else: - raise NotImplementedError - - if nb_pages == 1: - sample["img"] = pages[0][0] - sample["label_raw"] = pages[0][1]["raw"] - sample["label_begin"] = pages[0][1]["begin"] - sample["label_sem"] = pages[0][1]["sem"] - sample["label"] = pages[0][1] - sample["nb_cols"] = pages[0][2] - else: - if pages[0][0].shape[0] != pages[1][0].shape[0]: - max_height = max(pages[0][0].shape[0], pages[1][0].shape[0]) - backgrounds[0] = backgrounds[0][:max_height] - backgrounds[0][:pages[0][0].shape[0]] = pages[0][0] - backgrounds[1] = backgrounds[1][:max_height] - backgrounds[1][:pages[1][0].shape[0]] = pages[1][0] - pages[0][0] = backgrounds[0] - pages[1][0] = backgrounds[1] - sample["label_raw"] = pages[0][1]["raw"] + "\n" + pages[1][1]["raw"] - sample["label_begin"] = pages[0][1]["begin"] + pages[1][1]["begin"] - sample["label_sem"] = pages[0][1]["sem"] + pages[1][1]["sem"] - sample["img"] = np.concatenate([pages[0][0], pages[1][0]], axis=1) - sample["nb_cols"] = pages[0][2] + pages[1][2] - sample["label"] = sample["label_raw"] - if "â“‘" in self.charset: - sample["label"] = sample["label_begin"] - if "â’·" in self.charset: - sample["label"] = sample["label_sem"] - sample["unchanged_label"] = sample["label"] - sample = self.convert_sample_labels(sample) - return sample - - def generate_synthetic_rimes_page(self, background, nb_lines=20, crop=False): - max_nb_lines = self.get_syn_max_lines() - def larger_lines(label): - lines = label.split("\n") - new_lines = list() - while len(lines) > 0: - if len(lines) == 1: - new_lines.append(lines[0]) - del lines[0] - elif len(lines[0]) + len(lines[1]) < max_len: - new_lines.append("{} {}".format(lines[0], lines[1])) - del lines[1] - del lines[0] - else: - new_lines.append(lines[0]) - del lines[0] - return "\n".join(new_lines) - config = self.params["config"]["synthetic_data"] - max_len = 100 - matching_tokens = RIMES_MATCHING_TOKENS - matching_tokens_str = RIMES_MATCHING_TOKENS_STR - h, w, c = background.shape - num_lines = list() - for s in self.samples: - l = sum([len(p["label"].split("\n")) for p in s["paragraphs_label"]]) - num_lines.append(l) - stats = self.stat_sem_rimes() - ordered_modes = ['Corps de texte', 'PS/PJ', 'Ouverture', 'Date, Lieu', 'Coordonnées Expéditeur', 'Coordonnées Destinataire', ] - object_ref = ['Objet', 'Reference'] - random.shuffle(object_ref) - ordered_modes = ordered_modes[:3] + object_ref + ordered_modes[3:] - kept_modes = list() - for mode in ordered_modes: - if rand_uniform(0, 1) < stats[mode]: - kept_modes.append(mode) - - paragraphs = dict() - for mode in kept_modes: - paragraphs[mode] = self.get_paragraph_rimes(mode=mode, mix=True) - # proba to merge multiple body textual contents - if mode == "Corps de texte" and rand_uniform(0, 1) < 0.2: - nb_lines = min(nb_lines+10, max_nb_lines) if max_nb_lines < 30 else nb_lines+10 - concat_line = randint(0, 2) == 0 - if concat_line: - paragraphs[mode]["label"] = larger_lines(paragraphs[mode]["label"]) - while (len(paragraphs[mode]["label"].split("\n")) <= 30): - body2 = self.get_paragraph_rimes(mode=mode, mix=True) - paragraphs[mode]["label"] += "\n" + larger_lines(body2["label"]) if concat_line else body2["label"] - paragraphs[mode]["label"] = "\n".join(paragraphs[mode]["label"].split("\n")[:40]) - # proba to set whole text region to uppercase - if rand_uniform(0, 1) < 0.1 and "Corps de texte" in paragraphs: - paragraphs["Corps de texte"]["label"] = paragraphs["Corps de texte"]["label"].upper().replace("È", "E").replace("Ë", "E").replace("Û", "U").replace("Ù", "U").replace("ÃŽ", "I").replace("Ã", "I").replace("Â", "A").replace("Å’", "OE") - # proba to duplicate a line and place it randomly elsewhere, in a body region - if rand_uniform(0, 1) < 0.1 and "Corps de texte" in paragraphs: - labels = paragraphs["Corps de texte"]["label"].split("\n") - duplicated_label = labels[randint(0, len(labels))] - labels.insert(randint(0, len(labels)), duplicated_label) - paragraphs["Corps de texte"]["label"] = "\n".join(labels) - # proba to merge successive lines to have longer text lines in body - if rand_uniform(0, 1) < 0.1 and "Corps de texte" in paragraphs: - paragraphs["Corps de texte"]["label"] = larger_lines(paragraphs["Corps de texte"]["label"]) - for mode in paragraphs.keys(): - line_labels = paragraphs[mode]["label"].split("\n") - if len(line_labels) == 0: - print("ERROR") - paragraphs[mode]["lines"] = list() - for line_label in line_labels: - if len(line_label) > 100: - for chunk in [line_label[i:i + max_len] for i in range(0, len(line_label), max_len)]: - paragraphs[mode]["lines"].append(chunk) - else: - paragraphs[mode]["lines"].append(line_label) - page_labels = { - "raw": "", - "begin": "", - "sem": "" - } - top_limit = 0 - bottom_limit = h - max_bottom_crop = 0 - min_top_crop = h - has_opening = has_object = has_reference = False - top_opening = top_object = top_reference = 0 - right_opening = right_object = right_reference = 0 - has_reference = False - date_on_top = False - date_alone = False - for mode in kept_modes: - pg = paragraphs[mode] - if len(pg["lines"]) > nb_lines: - pg["lines"] = pg["lines"][:nb_lines] - nb_lines -= len(pg["lines"]) - pg_image = self.generate_typed_text_paragraph_image(pg["lines"], padding_value=255, max_pad_left_ratio=1, same_font_size=True) - # proba to remove some interline spacing - if rand_uniform(0, 1) < 0.1: - pg_image = apply_transform(pg_image, Tightening(color=255, remove_proba=0.75)) - # proba to rotate text region - if rand_uniform(0, 1) < 0.1: - pg_image = apply_transform(pg_image, RandomRotation(degrees=10, expand=True, fill=255)) - pg["added"] = True - if mode == 'Corps de texte': - pg_image = resize_max(pg_image, max_height=int(0.5*h), max_width=w) - img_h, img_w = pg_image.shape[:2] - min_top = int(0.4*h) - max_top = int(0.9*h - img_h) - top = randint(min_top, max_top + 1) - left = randint(0, int(w - img_w) + 1) - bottom_body = top + img_h - top_body = top - bottom_limit = min(top, bottom_limit) - elif mode == "PS/PJ": - pg_image = resize_max(pg_image, max_height=int(0.03*h), max_width=int(0.9*w)) - img_h, img_w = pg_image.shape[:2] - min_top = bottom_body - max_top = int(min(h - img_h, bottom_body + 0.15*h)) - top = randint(min_top, max_top + 1) - left = randint(0, int(w - img_w) + 1) - bottom_limit = min(top, bottom_limit) - elif mode == "Ouverture": - pg_image = resize_max(pg_image, max_height=int(0.03 * h), max_width=int(0.9 * w)) - img_h, img_w = pg_image.shape[:2] - min_top = int(top_body - 0.05 * h) - max_top = top_body - img_h - top = randint(min_top, max_top + 1) - left = randint(0, min(int(0.15*w), int(w - img_w)) + 1) - has_opening = True - top_opening = top - right_opening = left + img_w - bottom_limit = min(top, bottom_limit) - elif mode == "Objet": - pg_image = resize_max(pg_image, max_height=int(0.03 * h), max_width=int(0.9 * w)) - img_h, img_w = pg_image.shape[:2] - max_top = top_reference - img_h if has_reference else top_opening - img_h if has_opening else top_body - img_h - min_top = int(max_top - 0.05 * h) - top = randint(min_top, max_top + 1) - left = randint(0, min(int(0.15*w), int(w - img_w)) + 1) - has_object = True - top_object = top - right_object = left + img_w - bottom_limit = min(top, bottom_limit) - elif mode == "Reference": - pg_image = resize_max(pg_image, max_height=int(0.03 * h), max_width=int(0.9 * w)) - img_h, img_w = pg_image.shape[:2] - max_top = top_object - img_h if has_object else top_opening - img_h if has_opening else top_body - img_h - min_top = int(max_top - 0.05 * h) - top = randint(min_top, max_top + 1) - left = randint(0, min(int(0.15*w), int(w - img_w)) + 1) - has_reference = True - top_reference = top - right_reference = left + img_w - bottom_limit = min(top, bottom_limit) - elif mode == 'Date, Lieu': - pg_image = resize_max(pg_image, max_height=int(0.03 * h), max_width=int(0.45 * w)) - img_h, img_w = pg_image.shape[:2] - if h - max_bottom_crop - 10 > img_h and randint(0, 10) == 0: - top = randint(max_bottom_crop, h) - left = randint(0, w-img_w) - else: - min_top = top_body - img_h - max_top = top_body - img_h - min_left = 0 - # Check if there is anough place to put the date at the right side of opening, reference or object - if object_ref == ['Objet', 'Reference']: - have = [has_opening, has_object, has_reference] - rights = [right_opening, right_object, right_reference] - tops = [top_opening, top_object, top_reference] - else: - have = [has_opening, has_reference, has_object] - rights = [right_opening, right_reference, right_object] - tops = [top_opening, top_reference, top_object] - for right_r, top_r, has_r in zip(rights, tops, have): - if has_r: - if right_r + img_w >= 0.95*w: - max_top = min(top_r - img_h, max_top) - min_left = 0 - else: - min_left = max(min_left, right_r+0.05*w) - min_top = top_r - img_h if min_top == top_body - img_h else min_top - if min_left != 0 and randint(0, 5) == 0: - min_left = 0 - for right_r, top_r, has_r in zip(rights, tops, have): - if has_r: - max_top = min(max_top, top_r-img_h) - - max_left = max(min_left, w - img_w) - - # No placement found at right-side of opening, reference or object - if min_left == 0: - # place on the top - if randint(0, 2) == 0: - min_top = 0 - max_top = int(min(0.05*h, max_top)) - date_on_top = True - # place just before object/reference/opening - else: - min_top = int(max(0, max_top - 0.05*h)) - date_alone = True - max_left = min(max_left, int(0.1*w)) - - min_top = min(min_top, max_top) - top = randint(min_top, max_top + 1) - left = randint(int(min_left), max_left + 1) - if date_on_top: - top_limit = max(top_limit, top + img_h) - else: - bottom_limit = min(top, bottom_limit) - date_right = left + img_w - date_bottom = top + img_h - elif mode == "Coordonnées Expéditeur": - max_height = min(0.25*h, bottom_limit-top_limit) - if max_height <= 0: - pg["added"] = False - print("ko", bottom_limit, top_limit) - break - pg_image = resize_max(pg_image, max_height=int(max_height), max_width=int(0.45 * w)) - img_h, img_w = pg_image.shape[:2] - top = randint(top_limit, bottom_limit-img_h+1) - left = randint(0, int(0.5*w-img_w)+1) - elif mode == "Coordonnées Destinataire": - if h - max_bottom_crop - 10 > 0.2*h and randint(0, 10) == 0: - pg_image = resize_max(pg_image, max_height=int(0.2*h), max_width=int(0.45 * w)) - img_h, img_w = pg_image.shape[:2] - top = randint(max_bottom_crop, h) - left = randint(0, w-img_w) - else: - max_height = min(0.25*h, bottom_limit-top_limit) - if max_height <= 0: - pg["added"] = False - print("ko", bottom_limit, top_limit) - break - pg_image = resize_max(pg_image, max_height=int(max_height), max_width=int(0.45 * w)) - img_h, img_w = pg_image.shape[:2] - if date_alone and w - date_right - img_w > 11: - top = randint(0, date_bottom-img_h+1) - left = randint(max(int(0.5*w), date_right+10), w-img_w) - else: - top = randint(top_limit, bottom_limit-img_h+1) - left = randint(int(0.5*w), int(w - img_w)+1) - - bottom = top+img_h - right = left+img_w - min_top_crop = min(top, min_top_crop) - max_bottom_crop = max(bottom, max_bottom_crop) - try: - background[top:bottom, left:right, ...] = pg_image - except: - pg["added"] = False - nb_lines = 0 - pg["coords"] = { - "top": top, - "bottom": bottom, - "right": right, - "left": left - } - - if nb_lines <= 0: - break - sorted_pg = order_text_regions_rimes(paragraphs.values()) - for pg in sorted_pg: - if "added" in pg.keys() and pg["added"]: - pg_label = "\n".join(pg["lines"]) - mode = pg["type"] - begin_token = matching_tokens_str[mode] - end_token = matching_tokens[begin_token] - page_labels["raw"] += pg_label - page_labels["begin"] += begin_token + pg_label - page_labels["sem"] += begin_token + pg_label + end_token - if crop: - if min_top_crop > max_bottom_crop: - print("KO - min > MAX") - elif min_top_crop > h: - print("KO - min > h") - else: - background = background[min_top_crop:max_bottom_crop] - return [background, page_labels, 1] - - def stat_sem_rimes(self): - try: - return self.rimes_sem_stats - except: - stats = dict() - for sample in self.samples: - for pg in sample["paragraphs_label"]: - mode = pg["type"] - if mode == 'Coordonnées Expéditeur': - if len(pg["label"]) < 50 and "\n" not in pg["label"]: - mode = "Reference" - if mode not in stats.keys(): - stats[mode] = 0 - else: - stats[mode] += 1 - for key in stats: - stats[key] = max(0.10, stats[key]/len(self.samples)) - self.rimes_sem_stats = stats - return stats - - def get_paragraph_rimes(self, mode="Corps de texte", mix=False): - while True: - sample = self.samples[randint(0, len(self))] - random.shuffle(sample["paragraphs_label"]) - for pg in sample["paragraphs_label"]: - pg_mode = pg["type"] - if pg_mode == 'Coordonnées Expéditeur': - if len(pg["label"]) < 50 and "\n" not in pg["label"]: - pg_mode = "Reference" - if mode == pg_mode: - if mode == "Corps de texte" and mix: - return self.get_mix_paragraph_rimes(mode, min(5, len(pg["label"].split("\n")))) - else: - return pg - - def get_mix_paragraph_rimes(self, mode="Corps de texte", num_lines=10): - res = list() - while len(res) != num_lines: - sample = self.samples[randint(0, len(self))] - random.shuffle(sample["paragraphs_label"]) - for pg in sample["paragraphs_label"]: - pg_mode = pg["type"] - if pg_mode == 'Coordonnées Expéditeur': - if len(pg["label"]) < 50 and "\n" not in pg["label"]: - pg_mode = "Reference" - if mode == pg_mode: - lines = pg["label"].split("\n") - res.append(lines[randint(0, len(lines))]) - break - return { - "label": "\n".join(res), - "type": mode, - } - - def generate_synthetic_read2016_page(self, background, coords, side="left", nb_lines=20, crop=False): - config = self.params["config"]["synthetic_data"] - two_column = False - matching_token = READ_MATCHING_TOKENS - page_labels = { - "raw": "", - "begin": "â“Ÿ", - "sem": "â“Ÿ", - } - area_top = coords["top"] - area_left = coords["left"] - area_right = coords["right"] - area_bottom = coords["bottom"] - num_page_text_label = str(randint(0, 1000)) - num_page_img = self.generate_typed_text_line_image(num_page_text_label) - - if side == "left": - background[area_top:area_top+num_page_img.shape[0], area_left:area_left+num_page_img.shape[1]] = num_page_img - else: - background[area_top:area_top + num_page_img.shape[0], area_right-num_page_img.shape[1]:area_right] = num_page_img - for key in ["sem", "begin"]: - page_labels[key] += "â“" - for key in page_labels.keys(): - page_labels[key] += num_page_text_label - page_labels["sem"] += matching_token["â“"] - nb_lines -= 1 - area_top = area_top + num_page_img.shape[0] + randint(1, 20) - ratio_ann = rand_uniform(0.6, 0.7) - while nb_lines > 0: - nb_body_lines = randint(1, nb_lines+1) - max_ann_lines = min(nb_body_lines, nb_lines-nb_body_lines) - body_labels = list() - body_imgs = list() - while nb_body_lines > 0: - current_nb_lines = 1 - label, img = self.get_printed_line_read_2016("body") - - nb_body_lines -= current_nb_lines - body_labels.append(label) - body_imgs.append(img) - nb_ann_lines = randint(0, min(6, max_ann_lines+1)) - ann_labels = list() - ann_imgs = list() - while nb_ann_lines > 0: - current_nb_lines = 1 - label, img = self.get_printed_line_read_2016("annotation") - nb_ann_lines -= current_nb_lines - ann_labels.append(label) - ann_imgs.append(img) - max_width_body = int(np.floor(ratio_ann*(area_right-area_left))) - max_width_ann = area_right-area_left-max_width_body - for img_list, max_width in zip([body_imgs, ann_imgs], [max_width_body, max_width_ann]): - for i in range(len(img_list)): - if img_list[i].shape[1] > max_width: - ratio = max_width/img_list[i].shape[1] - new_h = int(np.floor(ratio*img_list[i].shape[0])) - new_w = int(np.floor(ratio*img_list[i].shape[1])) - img_list[i] = cv2.resize(img_list[i], (new_w, new_h), interpolation=cv2.INTER_LINEAR) - body_top = area_top - body_height = 0 - i_body = 0 - for (label, img) in zip(body_labels, body_imgs): - remaining_height = area_bottom - body_top - if img.shape[0] > remaining_height: - nb_lines = 0 - break - background[body_top:body_top+img.shape[0], area_left+max_width_ann:area_left+max_width_ann+img.shape[1]] = img - body_height += img.shape[0] - body_top += img.shape[0] - nb_lines -= 1 - i_body += 1 - - ann_height = int(np.sum([img.shape[0] for img in ann_imgs])) - ann_top = area_top + randint(0, body_height-ann_height+1) if ann_height < body_height else area_top - largest_ann = max([a.shape[1] for a in ann_imgs]) if len(ann_imgs) > 0 else max_width_ann - pad_ann = randint(0, max_width_ann-largest_ann+1) if max_width_ann > largest_ann else 0 - - ann_label_blocks = [list(), ] - i_ann = 0 - ann_height = 0 - for (label, img) in zip(ann_labels, ann_imgs): - remaining_height = body_top - ann_top - if img.shape[0] > remaining_height: - break - background[ann_top:ann_top+img.shape[0], area_left+pad_ann:area_left+pad_ann+img.shape[1]] = img - ann_height += img.shape[0] - ann_top += img.shape[0] - nb_lines -= 1 - two_column = True - ann_label_blocks[-1].append(ann_labels[i_ann]) - i_ann += 1 - if randint(0, 10) == 0: - ann_label_blocks.append(list()) - ann_top += randint(0, max(15, body_top-ann_top-20)) - - area_top = area_top + max(ann_height, body_height) + randint(25, 100) - - ann_full_labels = { - "raw": "", - "begin": "", - "sem": "", - } - for ann_label_block in ann_label_blocks: - if len(ann_label_block) > 0: - for key in ["sem", "begin"]: - ann_full_labels[key] += "â“" - ann_full_labels["raw"] += "\n" - for key in ann_full_labels.keys(): - ann_full_labels[key] += "\n".join(ann_label_block) - ann_full_labels["sem"] += matching_token["â“"] - - body_full_labels = { - "raw": "", - "begin": "", - "sem": "", - } - if i_body > 0: - for key in ["sem", "begin"]: - body_full_labels[key] += "â“‘" - body_full_labels["raw"] += "\n" - for key in body_full_labels.keys(): - body_full_labels[key] += "\n".join(body_labels[:i_body]) - body_full_labels["sem"] += matching_token["â“‘"] - - section_labels = dict() - for key in ann_full_labels.keys(): - section_labels[key] = ann_full_labels[key] + body_full_labels[key] - for key in section_labels.keys(): - if section_labels[key] != "": - if key in ["sem", "begin"]: - section_labels[key] = "â“¢" + section_labels[key] - if key == "sem": - section_labels[key] = section_labels[key] + matching_token["â“¢"] - for key in page_labels.keys(): - page_labels[key] += section_labels[key] - - if crop: - background = background[:area_top] - - page_labels["sem"] += matching_token["â“Ÿ"] - - for key in page_labels.keys(): - page_labels[key] = page_labels[key].strip() - - return [background, page_labels, 2 if two_column else 1] - - def get_n_consecutive_lines_read_2016(self, n=1, mode="body"): - while True: - sample = self.samples[randint(0, len(self))] - paragraphs = list() - for page in sample["pages_label"]: - paragraphs.extend(page["paragraphs"]) - random.shuffle(paragraphs) - for pg in paragraphs: - if ((mode == "body" and pg["mode"] == "body") or - (mode == "ann" and pg["mode"] == "annotation")) and len(pg["lines"]) >= n: - line_idx = randint(0, len(pg["lines"])-n+1) - lines = pg["lines"][line_idx:line_idx+n] - label = "\n".join([l["text"] for l in lines]) - top = min([l["top"] for l in lines]) - bottom = max([l["bottom"] for l in lines]) - left = min([l["left"] for l in lines]) - right = max([l["right"] for l in lines]) - img = sample["img"][top:bottom, left:right] - return label, img - - def get_printed_line_read_2016(self, mode="body"): - while True: - sample = self.samples[randint(0, len(self))] - for page in sample["pages_label"]: - paragraphs = list() - paragraphs.extend(page["paragraphs"]) - random.shuffle(paragraphs) - for pg in paragraphs: - random.shuffle(pg["lines"]) - for line in pg["lines"]: - if (mode == "body" and len(line["text"]) > 5) or (mode == "annotation" and len(line["text"]) < 15 and not line["text"].isdigit()): - label = line["text"] - img = self.generate_typed_text_line_image(label) - return label, img - - def generate_typed_text_line_image(self, text): - return generate_typed_text_line_image(text, self.params["config"]["synthetic_data"]["config"]) - - def generate_typed_text_paragraph_image(self, texts, padding_value=255, max_pad_left_ratio=0.1, same_font_size=False): - config = self.params["config"]["synthetic_data"]["config"] - if same_font_size: - images = list() - txt_color = config["text_color_default"] - bg_color = config["background_color_default"] - font_size = randint(config["font_size_min"], config["font_size_max"] + 1) - for text in texts: - font_path = config["valid_fonts"][randint(0, len(config["valid_fonts"]))] - fnt = ImageFont.truetype(font_path, font_size) - text_width, text_height = fnt.getsize(text) - padding_top = int(rand_uniform(config["padding_top_ratio_min"], config["padding_top_ratio_max"]) * text_height) - padding_bottom = int(rand_uniform(config["padding_bottom_ratio_min"], config["padding_bottom_ratio_max"]) * text_height) - padding_left = int(rand_uniform(config["padding_left_ratio_min"], config["padding_left_ratio_max"]) * text_width) - padding_right = int(rand_uniform(config["padding_right_ratio_min"], config["padding_right_ratio_max"]) * text_width) - padding = [padding_top, padding_bottom, padding_left, padding_right] - images.append(generate_typed_text_line_image_from_params(text, fnt, bg_color, txt_color, config["color_mode"], padding)) - else: - images = [self.generate_typed_text_line_image(t) for t in texts] - - max_width = max([img.shape[1] for img in images]) - - padded_images = [pad_image_width_random(img, max_width, padding_value=padding_value, max_pad_left_ratio=max_pad_left_ratio) for img in images] - return np.concatenate(padded_images, axis=0) - - - -class OCRCollateFunction: - """ - Merge samples data to mini-batch data for OCR task - """ - - def __init__(self, config): - self.img_padding_value = float(config["padding_value"]) - self.label_padding_value = config["padding_token"] - self.config = config - - def __call__(self, batch_data): - names = [batch_data[i]["name"] for i in range(len(batch_data))] - ids = [batch_data[i]["name"].split("/")[-1].split(".")[0] for i in range(len(batch_data))] - applied_da = [batch_data[i]["applied_da"] for i in range(len(batch_data))] - - labels = [batch_data[i]["token_label"] for i in range(len(batch_data))] - labels = pad_sequences_1D(labels, padding_value=self.label_padding_value) - labels = torch.tensor(labels).long() - reverse_labels = [[batch_data[i]["token_label"][0], ] + batch_data[i]["token_label"][-2:0:-1] + [batch_data[i]["token_label"][-1], ] for i in range(len(batch_data))] - reverse_labels = pad_sequences_1D(reverse_labels, padding_value=self.label_padding_value) - reverse_labels = torch.tensor(reverse_labels).long() - labels_len = [batch_data[i]["label_len"] for i in range(len(batch_data))] - - raw_labels = [batch_data[i]["label"] for i in range(len(batch_data))] - unchanged_labels = [batch_data[i]["unchanged_label"] for i in range(len(batch_data))] - - nb_cols = [batch_data[i]["nb_cols"] for i in range(len(batch_data))] - nb_lines = [batch_data[i]["nb_lines"] for i in range(len(batch_data))] - line_raw = [batch_data[i]["line_label"] for i in range(len(batch_data))] - line_token = [batch_data[i]["token_line_label"] for i in range(len(batch_data))] - pad_line_token = list() - line_len = [batch_data[i]["line_label_len"] for i in range(len(batch_data))] - for i in range(max(nb_lines)): - current_lines = [line_token[j][i] if i < nb_lines[j] else [self.label_padding_value] for j in range(len(batch_data))] - pad_line_token.append(torch.tensor(pad_sequences_1D(current_lines, padding_value=self.label_padding_value)).long()) - for j in range(len(batch_data)): - if i >= nb_lines[j]: - line_len[j].append(0) - line_len = [i for i in zip(*line_len)] - - nb_words = [batch_data[i]["nb_words"] for i in range(len(batch_data))] - word_raw = [batch_data[i]["word_label"] for i in range(len(batch_data))] - word_token = [batch_data[i]["token_word_label"] for i in range(len(batch_data))] - pad_word_token = list() - word_len = [batch_data[i]["word_label_len"] for i in range(len(batch_data))] - for i in range(max(nb_words)): - current_words = [word_token[j][i] if i < nb_words[j] else [self.label_padding_value] for j in range(len(batch_data))] - pad_word_token.append(torch.tensor(pad_sequences_1D(current_words, padding_value=self.label_padding_value)).long()) - for j in range(len(batch_data)): - if i >= nb_words[j]: - word_len[j].append(0) - word_len = [i for i in zip(*word_len)] - - padding_mode = self.config["padding_mode"] if "padding_mode" in self.config else "br" - imgs = [batch_data[i]["img"] for i in range(len(batch_data))] - imgs_shape = [batch_data[i]["img_shape"] for i in range(len(batch_data))] - imgs_reduced_shape = [batch_data[i]["img_reduced_shape"] for i in range(len(batch_data))] - imgs_position = [batch_data[i]["img_position"] for i in range(len(batch_data))] - imgs_reduced_position= [batch_data[i]["img_reduced_position"] for i in range(len(batch_data))] - imgs = pad_images(imgs, padding_value=self.img_padding_value, padding_mode=padding_mode) - imgs = torch.tensor(imgs).float().permute(0, 3, 1, 2) - formatted_batch_data = { - "names": names, - "ids": ids, - "nb_lines": nb_lines, - "nb_cols": nb_cols, - "labels": labels, - "reverse_labels": reverse_labels, - "raw_labels": raw_labels, - "unchanged_labels": unchanged_labels, - "labels_len": labels_len, - "imgs": imgs, - "imgs_shape": imgs_shape, - "imgs_reduced_shape": imgs_reduced_shape, - "imgs_position": imgs_position, - "imgs_reduced_position": imgs_reduced_position, - "line_raw": line_raw, - "line_labels": pad_line_token, - "line_labels_len": line_len, - "nb_words": nb_words, - "word_raw": word_raw, - "word_labels": pad_word_token, - "word_labels_len": word_len, - "applied_da": applied_da - } - - return formatted_batch_data - - -def generate_typed_text_line_image(text, config, bg_color=(255, 255, 255), txt_color=(0, 0, 0)): - if text == "": - text = " " - if "text_color_default" in config: - txt_color = config["text_color_default"] - if "background_color_default" in config: - bg_color = config["background_color_default"] - - font_path = config["valid_fonts"][randint(0, len(config["valid_fonts"]))] - font_size = randint(config["font_size_min"], config["font_size_max"]+1) - fnt = ImageFont.truetype(font_path, font_size) - - text_width, text_height = fnt.getsize(text) - padding_top = int(rand_uniform(config["padding_top_ratio_min"], config["padding_top_ratio_max"])*text_height) - padding_bottom = int(rand_uniform(config["padding_bottom_ratio_min"], config["padding_bottom_ratio_max"])*text_height) - padding_left = int(rand_uniform(config["padding_left_ratio_min"], config["padding_left_ratio_max"])*text_width) - padding_right = int(rand_uniform(config["padding_right_ratio_min"], config["padding_right_ratio_max"])*text_width) - padding = [padding_top, padding_bottom, padding_left, padding_right] - return generate_typed_text_line_image_from_params(text, fnt, bg_color, txt_color, config["color_mode"], padding) - - -def generate_typed_text_line_image_from_params(text, font, bg_color, txt_color, color_mode, padding): - padding_top, padding_bottom, padding_left, padding_right = padding - text_width, text_height = font.getsize(text) - img_height = padding_top + padding_bottom + text_height - img_width = padding_left + padding_right + text_width - img = Image.new(color_mode, (img_width, img_height), color=bg_color) - d = ImageDraw.Draw(img) - d.text((padding_left, padding_bottom), text, font=font, fill=txt_color, spacing=0) - return np.array(img) - - -def get_valid_fonts(alphabet=None): - valid_fonts = list() - for fold_detail in os.walk("../../../Fonts"): - if fold_detail[2]: - for font_name in fold_detail[2]: - if ".ttf" not in font_name: - continue - font_path = os.path.join(fold_detail[0], font_name) - to_add = True - if alphabet is not None: - for char in alphabet: - if not char_in_font(char, font_path): - to_add = False - break - if to_add: - valid_fonts.append(font_path) - else: - valid_fonts.append(font_path) - return valid_fonts - - -def char_in_font(unicode_char, font_path): - with TTFont(font_path) as font: - for cmap in font['cmap'].tables: - if cmap.isUnicode(): - if ord(unicode_char) in cmap.cmap: - return True - return False diff --git a/OCR/ocr_manager.py b/OCR/ocr_manager.py deleted file mode 100644 index e34b36c0a0793868c5a8f590b20510db72bc0681..0000000000000000000000000000000000000000 --- a/OCR/ocr_manager.py +++ /dev/null @@ -1,67 +0,0 @@ -from basic.generic_training_manager import GenericTrainingManager -import os -from PIL import Image -import pickle - - -class OCRManager(GenericTrainingManager): - def __init__(self, params): - super(OCRManager, self).__init__(params) - self.params["model_params"]["vocab_size"] = len(self.dataset.charset) - - def generate_syn_line_dataset(self, name): - """ - Generate synthetic line dataset from currently loaded dataset - """ - dataset_name = list(self.params['dataset_params']["datasets"].keys())[0] - path = os.path.join(os.path.dirname(self.params['dataset_params']["datasets"][dataset_name]), name) - os.makedirs(path, exist_ok=True) - charset = set() - dataset = None - gt = { - "train": dict(), - "valid": dict(), - "test": dict() - } - for set_name in ["train", "valid", "test"]: - set_path = os.path.join(path, set_name) - os.makedirs(set_path, exist_ok=True) - if set_name == "train": - dataset = self.dataset.train_dataset - elif set_name == "valid": - dataset = self.dataset.valid_datasets["{}-valid".format(dataset_name)] - elif set_name == "test": - self.dataset.generate_test_loader("{}-test".format(dataset_name), [(dataset_name, "test"), ]) - dataset = self.dataset.test_datasets["{}-test".format(dataset_name)] - - samples = list() - for sample in dataset.samples: - for line_label in sample["label"].split("\n"): - for chunk in [line_label[i:i+100] for i in range(0, len(line_label), 100)]: - charset = charset.union(set(chunk)) - if len(chunk) > 0: - samples.append({ - "path": sample["path"], - "label": chunk, - "nb_cols": 1, - }) - - for i, sample in enumerate(samples): - ext = sample['path'].split(".")[-1] - img_name = "{}_{}.{}".format(set_name, i, ext) - img_path = os.path.join(set_path, img_name) - - img = dataset.generate_typed_text_line_image(sample["label"]) - Image.fromarray(img).save(img_path) - gt[set_name][img_name] = { - "text": sample["label"], - "nb_cols": sample["nb_cols"] if "nb_cols" in sample else 1 - } - if "line_label" in sample: - gt[set_name][img_name]["lines"] = sample["line_label"] - - with open(os.path.join(path, "labels.pkl"), "wb") as f: - pickle.dump({ - "ground_truth": gt, - "charset": sorted(list(charset)), - }, f) \ No newline at end of file diff --git a/README.md b/README.md index 28d0dc9201d28151f4b494df0fc6685b991d0479..9b29f2040998f2a13c945eb8686490eab4d37069 100644 --- a/README.md +++ b/README.md @@ -2,10 +2,10 @@ This repository is a public implementation of the paper: "DAN: a Segmentation-free Document Attention Network for Handwritten Document Recognition". - + The model uses a character-level attention to handle slanted lines: - + The paper is available at https://arxiv.org/abs/2203.12273. diff --git a/basic/generic_dataset_manager.py b/basic/generic_dataset_manager.py deleted file mode 100644 index dd272d152884b4fc2d27f6467ad915f099f77acc..0000000000000000000000000000000000000000 --- a/basic/generic_dataset_manager.py +++ /dev/null @@ -1,390 +0,0 @@ -import torch -import random -from torch.utils.data import Dataset, DataLoader -from torch.utils.data.distributed import DistributedSampler -from basic.transforms import apply_data_augmentation -from Datasets.dataset_formatters.utils_dataset import natural_sort -import os -import numpy as np -import pickle -from PIL import Image -import cv2 - - -class DatasetManager: - - def __init__(self, params): - self.params = params - self.dataset_class = params["dataset_class"] - self.img_padding_value = params["config"]["padding_value"] - - self.my_collate_function = None - - self.train_dataset = None - self.valid_datasets = dict() - self.test_datasets = dict() - - self.train_loader = None - self.valid_loaders = dict() - self.test_loaders = dict() - - self.train_sampler = None - self.valid_samplers = dict() - self.test_samplers = dict() - - self.generator = torch.Generator() - self.generator.manual_seed(0) - - self.batch_size = { - "train": self.params["batch_size"], - "valid": self.params["valid_batch_size"] if "valid_batch_size" in self.params else self.params["batch_size"], - "test": self.params["test_batch_size"] if "test_batch_size" in self.params else 1, - } - - def apply_specific_treatment_after_dataset_loading(self, dataset): - raise NotImplementedError - - def load_datasets(self): - """ - Load training and validation datasets - """ - self.train_dataset = self.dataset_class(self.params, "train", self.params["train"]["name"], self.get_paths_and_sets(self.params["train"]["datasets"])) - self.params["config"]["mean"], self.params["config"]["std"] = self.train_dataset.compute_std_mean() - - self.my_collate_function = self.train_dataset.collate_function(self.params["config"]) - self.apply_specific_treatment_after_dataset_loading(self.train_dataset) - - for custom_name in self.params["valid"].keys(): - self.valid_datasets[custom_name] = self.dataset_class(self.params, "valid", custom_name, self.get_paths_and_sets(self.params["valid"][custom_name])) - self.apply_specific_treatment_after_dataset_loading(self.valid_datasets[custom_name]) - - def load_ddp_samplers(self): - """ - Load training and validation data samplers - """ - if self.params["use_ddp"]: - self.train_sampler = DistributedSampler(self.train_dataset, num_replicas=self.params["num_gpu"], rank=self.params["ddp_rank"], shuffle=True) - for custom_name in self.valid_datasets.keys(): - self.valid_samplers[custom_name] = DistributedSampler(self.valid_datasets[custom_name], num_replicas=self.params["num_gpu"], rank=self.params["ddp_rank"], shuffle=False) - else: - for custom_name in self.valid_datasets.keys(): - self.valid_samplers[custom_name] = None - - def load_dataloaders(self): - """ - Load training and validation data loaders - """ - self.train_loader = DataLoader(self.train_dataset, - batch_size=self.batch_size["train"], - shuffle=True if self.train_sampler is None else False, - drop_last=False, - batch_sampler=self.train_sampler, - sampler=self.train_sampler, - num_workers=self.params["num_gpu"]*self.params["worker_per_gpu"], - pin_memory=True, - collate_fn=self.my_collate_function, - worker_init_fn=self.seed_worker, - generator=self.generator) - - for key in self.valid_datasets.keys(): - self.valid_loaders[key] = DataLoader(self.valid_datasets[key], - batch_size=self.batch_size["valid"], - sampler=self.valid_samplers[key], - batch_sampler=self.valid_samplers[key], - shuffle=False, - num_workers=self.params["num_gpu"]*self.params["worker_per_gpu"], - pin_memory=True, - drop_last=False, - collate_fn=self.my_collate_function, - worker_init_fn=self.seed_worker, - generator=self.generator) - - @staticmethod - def seed_worker(worker_id): - worker_seed = torch.initial_seed() % 2 ** 32 - np.random.seed(worker_seed) - random.seed(worker_seed) - - def generate_test_loader(self, custom_name, sets_list): - """ - Load test dataset, data sampler and data loader - """ - if custom_name in self.test_loaders.keys(): - return - paths_and_sets = list() - for set_info in sets_list: - paths_and_sets.append({ - "path": self.params["datasets"][set_info[0]], - "set_name": set_info[1] - }) - self.test_datasets[custom_name] = self.dataset_class(self.params, "test", custom_name, paths_and_sets) - self.apply_specific_treatment_after_dataset_loading(self.test_datasets[custom_name]) - if self.params["use_ddp"]: - self.test_samplers[custom_name] = DistributedSampler(self.test_datasets[custom_name], num_replicas=self.params["num_gpu"], rank=self.params["ddp_rank"], shuffle=False) - else: - self.test_samplers[custom_name] = None - self.test_loaders[custom_name] = DataLoader(self.test_datasets[custom_name], - batch_size=self.batch_size["test"], - sampler=self.test_samplers[custom_name], - shuffle=False, - num_workers=self.params["num_gpu"]*self.params["worker_per_gpu"], - pin_memory=True, - drop_last=False, - collate_fn=self.my_collate_function, - worker_init_fn=self.seed_worker, - generator=self.generator) - - def remove_test_dataset(self, custom_name): - del self.test_datasets[custom_name] - del self.test_samplers[custom_name] - del self.test_loaders[custom_name] - - def remove_valid_dataset(self, custom_name): - del self.valid_datasets[custom_name] - del self.valid_samplers[custom_name] - del self.valid_loaders[custom_name] - - def remove_train_dataset(self): - self.train_dataset = None - self.train_sampler = None - self.train_loader = None - - def remove_all_datasets(self): - self.remove_train_dataset() - for name in list(self.valid_datasets.keys()): - self.remove_valid_dataset(name) - for name in list(self.test_datasets.keys()): - self.remove_test_dataset(name) - - def get_paths_and_sets(self, dataset_names_folds): - paths_and_sets = list() - for dataset_name, fold in dataset_names_folds: - path = self.params["datasets"][dataset_name] - paths_and_sets.append({ - "path": path, - "set_name": fold - }) - return paths_and_sets - - -class GenericDataset(Dataset): - """ - Main class to handle dataset loading - """ - - def __init__(self, params, set_name, custom_name, paths_and_sets): - self.params = params - self.name = custom_name - self.set_name = set_name - self.mean = np.array(params["config"]["mean"]) if "mean" in params["config"].keys() else None - self.std = np.array(params["config"]["std"]) if "std" in params["config"].keys() else None - - self.load_in_memory = self.params["config"]["load_in_memory"] if "load_in_memory" in self.params["config"] else True - - self.samples = self.load_samples(paths_and_sets, load_in_memory=self.load_in_memory) - - if self.load_in_memory: - self.apply_preprocessing(params["config"]["preprocessings"]) - - self.padding_value = params["config"]["padding_value"] - if self.padding_value == "mean": - if self.mean is None: - _, _ = self.compute_std_mean() - self.padding_value = self.mean - self.params["config"]["padding_value"] = self.padding_value - - self.curriculum_config = None - self.training_info = None - - def __len__(self): - return len(self.samples) - - @staticmethod - def load_image(path): - with Image.open(path) as pil_img: - img = np.array(pil_img) - ## grayscale images - if len(img.shape) == 2: - img = np.expand_dims(img, axis=2) - return img - - @staticmethod - def load_samples(paths_and_sets, load_in_memory=True): - """ - Load images and labels - """ - samples = list() - for path_and_set in paths_and_sets: - path = path_and_set["path"] - set_name = path_and_set["set_name"] - with open(os.path.join(path, "labels.pkl"), "rb") as f: - info = pickle.load(f) - gt = info["ground_truth"][set_name] - for filename in natural_sort(gt.keys()): - name = os.path.join(os.path.basename(path), set_name, filename) - full_path = os.path.join(path, set_name, filename) - if isinstance(gt[filename], dict) and "text" in gt[filename]: - label = gt[filename]["text"] - else: - label = gt[filename] - samples.append({ - "name": name, - "label": label, - "unchanged_label": label, - "path": full_path, - "nb_cols": 1 if "nb_cols" not in gt[filename] else gt[filename]["nb_cols"] - }) - if load_in_memory: - samples[-1]["img"] = GenericDataset.load_image(full_path) - if type(gt[filename]) is dict: - if "lines" in gt[filename].keys(): - samples[-1]["raw_line_seg_label"] = gt[filename]["lines"] - if "paragraphs" in gt[filename].keys(): - samples[-1]["paragraphs_label"] = gt[filename]["paragraphs"] - if "pages" in gt[filename].keys(): - samples[-1]["pages_label"] = gt[filename]["pages"] - return samples - - def apply_preprocessing(self, preprocessings): - for i in range(len(self.samples)): - self.samples[i] = apply_preprocessing(self.samples[i], preprocessings) - - def compute_std_mean(self): - """ - Compute cumulated variance and mean of whole dataset - """ - if self.mean is not None and self.std is not None: - return self.mean, self.std - if not self.load_in_memory: - sample = self.samples[0].copy() - sample["img"] = self.get_sample_img(0) - img = apply_preprocessing(sample, self.params["config"]["preprocessings"])["img"] - else: - img = self.get_sample_img(0) - _, _, c = img.shape - sum = np.zeros((c,)) - nb_pixels = 0 - - for i in range(len(self.samples)): - if not self.load_in_memory: - sample = self.samples[i].copy() - sample["img"] = self.get_sample_img(i) - img = apply_preprocessing(sample, self.params["config"]["preprocessings"])["img"] - else: - img = self.get_sample_img(i) - sum += np.sum(img, axis=(0, 1)) - nb_pixels += np.prod(img.shape[:2]) - mean = sum / nb_pixels - diff = np.zeros((c,)) - for i in range(len(self.samples)): - if not self.load_in_memory: - sample = self.samples[i].copy() - sample["img"] = self.get_sample_img(i) - img = apply_preprocessing(sample, self.params["config"]["preprocessings"])["img"] - else: - img = self.get_sample_img(i) - diff += [np.sum((img[:, :, k] - mean[k]) ** 2) for k in range(c)] - std = np.sqrt(diff / nb_pixels) - - self.mean = mean - self.std = std - return mean, std - - def apply_data_augmentation(self, img): - """ - Apply data augmentation strategy on the input image - """ - augs = [self.params["config"][key] if key in self.params["config"].keys() else None for key in ["augmentation", "valid_augmentation", "test_augmentation"]] - for aug, set_name in zip(augs, ["train", "valid", "test"]): - if aug and self.set_name == set_name: - return apply_data_augmentation(img, aug) - return img, list() - - def get_sample_img(self, i): - """ - Get image by index - """ - if self.load_in_memory: - return self.samples[i]["img"] - else: - return GenericDataset.load_image(self.samples[i]["path"]) - - def denormalize(self, img): - """ - Get original image, before normalization - """ - return img * self.std + self.mean - - -def apply_preprocessing(sample, preprocessings): - """ - Apply preprocessings on each sample - """ - resize_ratio = [1, 1] - img = sample["img"] - for preprocessing in preprocessings: - - if preprocessing["type"] == "dpi": - ratio = preprocessing["target"] / preprocessing["source"] - temp_img = img - h, w, c = temp_img.shape - temp_img = cv2.resize(temp_img, (int(np.ceil(w * ratio)), int(np.ceil(h * ratio)))) - if len(temp_img.shape) == 2: - temp_img = np.expand_dims(temp_img, axis=2) - img = temp_img - - resize_ratio = [ratio, ratio] - - if preprocessing["type"] == "to_grayscaled": - temp_img = img - h, w, c = temp_img.shape - if c == 3: - img = np.expand_dims( - 0.2125 * temp_img[:, :, 0] + 0.7154 * temp_img[:, :, 1] + 0.0721 * temp_img[:, :, 2], - axis=2).astype(np.uint8) - - if preprocessing["type"] == "to_RGB": - temp_img = img - h, w, c = temp_img.shape - if c == 1: - img = np.concatenate([temp_img, temp_img, temp_img], axis=2) - - if preprocessing["type"] == "resize": - keep_ratio = preprocessing["keep_ratio"] - max_h, max_w = preprocessing["max_height"], preprocessing["max_width"] - temp_img = img - h, w, c = temp_img.shape - - ratio_h = max_h / h if max_h else 1 - ratio_w = max_w / w if max_w else 1 - if keep_ratio: - ratio_h = ratio_w = min(ratio_w, ratio_h) - new_h = min(max_h, int(h * ratio_h)) - new_w = min(max_w, int(w * ratio_w)) - temp_img = cv2.resize(temp_img, (new_w, new_h)) - if len(temp_img.shape) == 2: - temp_img = np.expand_dims(temp_img, axis=2) - - img = temp_img - resize_ratio = [ratio_h, ratio_w] - - if preprocessing["type"] == "fixed_height": - new_h = preprocessing["height"] - temp_img = img - h, w, c = temp_img.shape - ratio = new_h / h - temp_img = cv2.resize(temp_img, (int(w*ratio), new_h)) - if len(temp_img.shape) == 2: - temp_img = np.expand_dims(temp_img, axis=2) - img = temp_img - resize_ratio = [ratio, ratio] - if resize_ratio != [1, 1] and "raw_line_seg_label" in sample: - for li in range(len(sample["raw_line_seg_label"])): - for side, ratio in zip((["bottom", "top"], ["right", "left"]), resize_ratio): - for s in side: - sample["raw_line_seg_label"][li][s] = sample["raw_line_seg_label"][li][s] * ratio - - sample["img"] = img - sample["resize_ratio"] = resize_ratio - return sample - diff --git a/basic/generic_training_manager.py b/basic/generic_training_manager.py deleted file mode 100644 index a8f41b28274678b23083123e33d2343bb6b942b9..0000000000000000000000000000000000000000 --- a/basic/generic_training_manager.py +++ /dev/null @@ -1,706 +0,0 @@ -import torch -import os -import sys -import copy -import json -import torch.distributed as dist -import torch.multiprocessing as mp -import random -import numpy as np -from torch.utils.tensorboard import SummaryWriter -from torch.nn.init import kaiming_uniform_ -from tqdm import tqdm -from time import time -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.cuda.amp import GradScaler -from basic.metric_manager import MetricManager -from basic.scheduler import DropoutScheduler -from datetime import date - - -class GenericTrainingManager: - - def __init__(self, params): - self.type = None - self.is_master = False - self.params = params - self.dropout_scheduler = None - self.models = {} - self.begin_time = None - self.dataset = None - self.dataset_name = list(self.params["dataset_params"]["datasets"].values())[0] - self.paths = None - self.latest_step = 0 - self.latest_epoch = -1 - self.latest_batch = 0 - self.total_batch = 0 - self.grad_acc_step = 0 - self.latest_train_metrics = dict() - self.latest_valid_metrics = dict() - self.curriculum_info = dict() - self.curriculum_info["latest_valid_metrics"] = dict() - self.phase = None - self.max_mem_usage_by_epoch = list() - self.losses = list() - self.lr_values = list() - - self.scaler = None - - self.optimizers = dict() - self.optimizers_named_params_by_group = dict() - self.lr_schedulers = dict() - self.best = None - self.writer = None - self.metric_manager = dict() - - self.init_hardware_config() - self.init_paths() - self.load_dataset() - self.params["model_params"]["use_amp"] = self.params["training_params"]["use_amp"] - - def init_paths(self): - """ - Create output folders for results and checkpoints - """ - output_path = os.path.join("outputs", self.params["training_params"]["output_folder"]) - os.makedirs(output_path, exist_ok=True) - checkpoints_path = os.path.join(output_path, "checkpoints") - os.makedirs(checkpoints_path, exist_ok=True) - results_path = os.path.join(output_path, "results") - os.makedirs(results_path, exist_ok=True) - - self.paths = { - "results": results_path, - "checkpoints": checkpoints_path, - "output_folder": output_path - } - - def load_dataset(self): - """ - Load datasets, data samplers and data loaders - """ - self.params["dataset_params"]["use_ddp"] = self.params["training_params"]["use_ddp"] - self.params["dataset_params"]["batch_size"] = self.params["training_params"]["batch_size"] - if "valid_batch_size" in self.params["training_params"]: - self.params["dataset_params"]["valid_batch_size"] = self.params["training_params"]["valid_batch_size"] - if "test_batch_size" in self.params["training_params"]: - self.params["dataset_params"]["test_batch_size"] = self.params["training_params"]["test_batch_size"] - self.params["dataset_params"]["num_gpu"] = self.params["training_params"]["nb_gpu"] - self.params["dataset_params"]["worker_per_gpu"] = 4 if "worker_per_gpu" not in self.params["dataset_params"] else self.params["dataset_params"]["worker_per_gpu"] - self.dataset = self.params["dataset_params"]["dataset_manager"](self.params["dataset_params"]) - self.dataset.load_datasets() - self.dataset.load_ddp_samplers() - self.dataset.load_dataloaders() - - def init_hardware_config(self): - # Debug mode - if self.params["training_params"]["force_cpu"]: - self.params["training_params"]["use_ddp"] = False - self.params["training_params"]["use_amp"] = False - # Manage Distributed Data Parallel & GPU usage - self.manual_seed = 1111 if "manual_seed" not in self.params["training_params"].keys() else \ - self.params["training_params"]["manual_seed"] - self.ddp_config = { - "master": self.params["training_params"]["use_ddp"] and self.params["training_params"]["ddp_rank"] == 0, - "address": "localhost" if "ddp_addr" not in self.params["training_params"].keys() else self.params["training_params"]["ddp_addr"], - "port": "11111" if "ddp_port" not in self.params["training_params"].keys() else self.params["training_params"]["ddp_port"], - "backend": "nccl" if "ddp_backend" not in self.params["training_params"].keys() else self.params["training_params"]["ddp_backend"], - "rank": self.params["training_params"]["ddp_rank"], - } - self.is_master = self.ddp_config["master"] or not self.params["training_params"]["use_ddp"] - if self.params["training_params"]["force_cpu"]: - self.device = "cpu" - else: - if self.params["training_params"]["use_ddp"]: - self.device = torch.device(self.ddp_config["rank"]) - self.params["dataset_params"]["ddp_rank"] = self.ddp_config["rank"] - self.launch_ddp() - else: - self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - self.params["model_params"]["device"] = self.device.type - # Print GPU info - # global - if (self.params["training_params"]["use_ddp"] and self.ddp_config["master"]) or not self.params["training_params"]["use_ddp"]: - print("##################") - print("Available GPUS: {}".format(self.params["training_params"]["nb_gpu"])) - for i in range(self.params["training_params"]["nb_gpu"]): - print("Rank {}: {} {}".format(i, torch.cuda.get_device_name(i), torch.cuda.get_device_properties(i))) - print("##################") - # local - print("Local GPU:") - if self.device != "cpu": - print("Rank {}: {} {}".format(self.params["training_params"]["ddp_rank"], torch.cuda.get_device_name(), torch.cuda.get_device_properties(self.device))) - else: - print("WORKING ON CPU !\n") - print("##################") - - def load_model(self, reset_optimizer=False, strict=True): - """ - Load model weights from scratch or from checkpoints - """ - # Instantiate Model - for model_name in self.params["model_params"]["models"].keys(): - self.models[model_name] = self.params["model_params"]["models"][model_name](self.params["model_params"]) - self.models[model_name].to(self.device) # To GPU or CPU - # make the model compatible with Distributed Data Parallel if used - if self.params["training_params"]["use_ddp"]: - self.models[model_name] = DDP(self.models[model_name], [self.ddp_config["rank"]]) - - # Handle curriculum dropout - if "dropout_scheduler" in self.params["model_params"]: - func = self.params["model_params"]["dropout_scheduler"]["function"] - T = self.params["model_params"]["dropout_scheduler"]["T"] - self.dropout_scheduler = DropoutScheduler(self.models, func, T) - - self.scaler = GradScaler(enabled=self.params["training_params"]["use_amp"]) - - # Check if checkpoint exists - checkpoint = self.get_checkpoint() - if checkpoint is not None: - self.load_existing_model(checkpoint, strict=strict) - else: - self.init_new_model() - - self.load_optimizers(checkpoint, reset_optimizer=reset_optimizer) - - if self.is_master: - print("LOADED EPOCH: {}\n".format(self.latest_epoch), flush=True) - - def get_checkpoint(self): - """ - Seek if checkpoint exist, return None otherwise - """ - if self.params["training_params"]["load_epoch"] in ("best", "last"): - for filename in os.listdir(self.paths["checkpoints"]): - if self.params["training_params"]["load_epoch"] in filename: - return torch.load(os.path.join(self.paths["checkpoints"], filename)) - return None - - def load_existing_model(self, checkpoint, strict=True): - """ - Load information and weights from previous training - """ - self.load_save_info(checkpoint) - self.latest_epoch = checkpoint["epoch"] - if "step" in checkpoint: - self.latest_step = checkpoint["step"] - self.best = checkpoint["best"] - if "scaler_state_dict" in checkpoint: - self.scaler.load_state_dict(checkpoint["scaler_state_dict"]) - # Load model weights from past training - for model_name in self.models.keys(): - self.models[model_name].load_state_dict(checkpoint["{}_state_dict".format(model_name)], strict=strict) - - def init_new_model(self): - """ - Initialize model - """ - # Specific weights initialization if exists - for model_name in self.models.keys(): - try: - self.models[model_name].init_weights() - except: - pass - - # Handle transfer learning instructions - if self.params["model_params"]["transfer_learning"]: - # Iterates over models - for model_name in self.params["model_params"]["transfer_learning"].keys(): - state_dict_name, path, learnable, strict = self.params["model_params"]["transfer_learning"][model_name] - # Loading pretrained weights file - checkpoint = torch.load(path) - try: - # Load pretrained weights for model - self.models[model_name].load_state_dict(checkpoint["{}_state_dict".format(state_dict_name)], strict=strict) - print("transfered weights for {}".format(state_dict_name), flush=True) - except RuntimeError as e: - print(e, flush=True) - # if error, try to load each parts of the model (useful if only few layers are different) - for key in checkpoint["{}_state_dict".format(state_dict_name)].keys(): - try: - # for pre-training of decision layer - if "end_conv" in key and "transfered_charset" in self.params["model_params"]: - self.adapt_decision_layer_to_old_charset(model_name, key, checkpoint, state_dict_name) - else: - self.models[model_name].load_state_dict( - {key: checkpoint["{}_state_dict".format(state_dict_name)][key]}, strict=False) - except RuntimeError as e: - ## exception when adding linebreak token from pretraining - print(e, flush=True) - # Set parameters no trainable - if not learnable: - self.set_model_learnable(self.models[model_name], False) - - def adapt_decision_layer_to_old_charset(self, model_name, key, checkpoint, state_dict_name): - """ - Transfer learning of the decision learning in case of close charsets between pre-training and training - """ - pretrained_chars = list() - weights = checkpoint["{}_state_dict".format(state_dict_name)][key] - new_size = list(weights.size()) - new_size[0] = len(self.dataset.charset) + self.params["model_params"]["additional_tokens"] - new_weights = torch.zeros(new_size, device=weights.device, dtype=weights.dtype) - old_charset = checkpoint["charset"] if "charset" in checkpoint else self.params["model_params"]["old_charset"] - if not "bias" in key: - kaiming_uniform_(new_weights, nonlinearity="relu") - for i, c in enumerate(self.dataset.charset): - if c in old_charset: - new_weights[i] = weights[old_charset.index(c)] - pretrained_chars.append(c) - if "transfered_charset_last_is_ctc_blank" in self.params["model_params"] and self.params["model_params"]["transfered_charset_last_is_ctc_blank"]: - new_weights[-1] = weights[-1] - pretrained_chars.append("<blank>") - checkpoint["{}_state_dict".format(state_dict_name)][key] = new_weights - self.models[model_name].load_state_dict({key: checkpoint["{}_state_dict".format(state_dict_name)][key]}, strict=False) - print("Pretrained chars for {} ({}): {}".format(key, len(pretrained_chars), pretrained_chars)) - - def load_optimizers(self, checkpoint, reset_optimizer=False): - """ - Load the optimizer of each model - """ - for model_name in self.models.keys(): - new_params = dict() - if checkpoint and "optimizer_named_params_{}".format(model_name) in checkpoint: - self.optimizers_named_params_by_group[model_name] = checkpoint["optimizer_named_params_{}".format(model_name)] - # for progressively growing models - for name, param in self.models[model_name].named_parameters(): - existing = False - for gr in self.optimizers_named_params_by_group[model_name]: - if name in gr: - gr[name] = param - existing = True - break - if not existing: - new_params.update({name: param}) - else: - self.optimizers_named_params_by_group[model_name] = [dict(), ] - self.optimizers_named_params_by_group[model_name][0].update(self.models[model_name].named_parameters()) - - # Instantiate optimizer - self.reset_optimizer(model_name) - - # Handle learning rate schedulers - if "lr_schedulers" in self.params["training_params"] and self.params["training_params"]["lr_schedulers"]: - key = "all" if "all" in self.params["training_params"]["lr_schedulers"] else model_name - if key in self.params["training_params"]["lr_schedulers"]: - self.lr_schedulers[model_name] = self.params["training_params"]["lr_schedulers"][key]["class"]\ - (self.optimizers[model_name], **self.params["training_params"]["lr_schedulers"][key]["args"]) - - # Load optimizer state from past training - if checkpoint and not reset_optimizer: - self.optimizers[model_name].load_state_dict(checkpoint["optimizer_{}_state_dict".format(model_name)]) - # Load optimizer scheduler config from past training if used - if "lr_schedulers" in self.params["training_params"] and self.params["training_params"]["lr_schedulers"] \ - and "lr_scheduler_{}_state_dict".format(model_name) in checkpoint.keys(): - self.lr_schedulers[model_name].load_state_dict(checkpoint["lr_scheduler_{}_state_dict".format(model_name)]) - - # for progressively growing models, keeping learning rate - if checkpoint and new_params: - self.optimizers_named_params_by_group[model_name].append(new_params) - self.optimizers[model_name].add_param_group({"params": list(new_params.values())}) - - @staticmethod - def set_model_learnable(model, learnable=True): - for p in list(model.parameters()): - p.requires_grad = learnable - - def save_model(self, epoch, name, keep_weights=False): - """ - Save model weights and training info for curriculum learning or learning rate for instance - """ - if not self.is_master: - return - to_del = [] - for filename in os.listdir(self.paths["checkpoints"]): - if name in filename: - to_del.append(os.path.join(self.paths["checkpoints"], filename)) - path = os.path.join(self.paths["checkpoints"], "{}_{}.pt".format(name, epoch)) - content = { - 'optimizers_named_params': self.optimizers_named_params_by_group, - 'epoch': epoch, - 'step': self.latest_step, - "scaler_state_dict": self.scaler.state_dict(), - 'best': self.best, - "charset": self.dataset.charset - } - for model_name in self.optimizers: - content['optimizer_{}_state_dict'.format(model_name)] = self.optimizers[model_name].state_dict() - for model_name in self.lr_schedulers: - content["lr_scheduler_{}_state_dict".format(model_name)] = self.lr_schedulers[model_name].state_dict() - content = self.add_save_info(content) - for model_name in self.models.keys(): - content["{}_state_dict".format(model_name)] = self.models[model_name].state_dict() - torch.save(content, path) - if not keep_weights: - for path_to_del in to_del: - if path_to_del != path: - os.remove(path_to_del) - - def reset_optimizers(self): - """ - Reset learning rate of all optimizers - """ - for model_name in self.models.keys(): - self.reset_optimizer(model_name) - - def reset_optimizer(self, model_name): - """ - Reset optimizer learning rate for given model - """ - params = list(self.optimizers_named_params_by_group[model_name][0].values()) - key = "all" if "all" in self.params["training_params"]["optimizers"] else model_name - self.optimizers[model_name] = self.params["training_params"]["optimizers"][key]["class"](params, **self.params["training_params"]["optimizers"][key]["args"]) - for i in range(1, len(self.optimizers_named_params_by_group[model_name])): - self.optimizers[model_name].add_param_group({"params": list(self.optimizers_named_params_by_group[model_name][i].values())}) - - def save_params(self): - """ - Output text file containing a summary of all hyperparameters chosen for the training - """ - def compute_nb_params(module): - return sum([np.prod(p.size()) for p in list(module.parameters())]) - - def class_to_str_dict(my_dict): - for key in my_dict.keys(): - if callable(my_dict[key]): - my_dict[key] = my_dict[key].__name__ - elif isinstance(my_dict[key], np.ndarray): - my_dict[key] = my_dict[key].tolist() - elif isinstance(my_dict[key], dict): - my_dict[key] = class_to_str_dict(my_dict[key]) - return my_dict - - path = os.path.join(self.paths["results"], "params") - if os.path.isfile(path): - return - params = copy.deepcopy(self.params) - params = class_to_str_dict(params) - params["date"] = date.today().strftime("%d/%m/%Y") - total_params = 0 - for model_name in self.models.keys(): - current_params = compute_nb_params(self.models[model_name]) - params["model_params"]["models"][model_name] = [params["model_params"]["models"][model_name], "{:,}".format(current_params)] - total_params += current_params - params["model_params"]["total_params"] = "{:,}".format(total_params) - - params["hardware"] = dict() - if self.device != "cpu": - for i in range(self.params["training_params"]["nb_gpu"]): - params["hardware"][str(i)] = "{} {}".format(torch.cuda.get_device_name(i), torch.cuda.get_device_properties(i)) - else: - params["hardware"]["0"] = "CPU" - params["software"] = { - "python_version": sys.version, - "pytorch_version": torch.__version__, - "cuda_version": torch.version.cuda, - "cudnn_version": torch.backends.cudnn.version(), - } - with open(path, 'w') as f: - json.dump(params, f, indent=4) - - def backward_loss(self, loss, retain_graph=False): - self.scaler.scale(loss).backward(retain_graph=retain_graph) - - def step_optimizers(self, increment_step=True, names=None): - for model_name in self.optimizers: - if names and model_name not in names: - continue - if "gradient_clipping" in self.params["training_params"] and model_name in self.params["training_params"]["gradient_clipping"]["models"]: - self.scaler.unscale_(self.optimizers[model_name]) - torch.nn.utils.clip_grad_norm_(self.models[model_name].parameters(), self.params["training_params"]["gradient_clipping"]["max"]) - self.scaler.step(self.optimizers[model_name]) - self.scaler.update() - self.latest_step += 1 - - def zero_optimizers(self, set_to_none=True): - for model_name in self.optimizers: - self.zero_optimizer(model_name, set_to_none) - - def zero_optimizer(self, model_name, set_to_none=True): - self.optimizers[model_name].zero_grad(set_to_none=set_to_none) - - def train(self): - """ - Main training loop - """ - # init tensorboard file and output param summary file - if self.is_master: - self.writer = SummaryWriter(self.paths["results"]) - self.save_params() - # init variables - self.begin_time = time() - focus_metric_name = self.params["training_params"]["focus_metric"] - nb_epochs = self.params["training_params"]["max_nb_epochs"] - interval_save_weights = self.params["training_params"]["interval_save_weights"] - metric_names = self.params["training_params"]["train_metrics"] - - display_values = None - # init curriculum learning - if "curriculum_learning" in self.params["training_params"].keys() and self.params["training_params"]["curriculum_learning"]: - self.init_curriculum() - # perform epochs - for num_epoch in range(self.latest_epoch+1, nb_epochs): - self.dataset.train_dataset.training_info = { - "epoch": self.latest_epoch, - "step": self.latest_step - } - self.phase = "train" - # Check maximum training time stop condition - if self.params["training_params"]["max_training_time"] and time() - self.begin_time > self.params["training_params"]["max_training_time"]: - break - # set models trainable - for model_name in self.models.keys(): - self.models[model_name].train() - self.latest_epoch = num_epoch - if self.dataset.train_dataset.curriculum_config: - self.dataset.train_dataset.curriculum_config["epoch"] = self.latest_epoch - # init epoch metrics values - self.metric_manager["train"] = MetricManager(metric_names=metric_names, dataset_name=self.dataset_name) - - with tqdm(total=len(self.dataset.train_loader.dataset)) as pbar: - pbar.set_description("EPOCH {}/{}".format(num_epoch, nb_epochs)) - # iterates over mini-batch data - for ind_batch, batch_data in enumerate(self.dataset.train_loader): - self.latest_batch = ind_batch + 1 - self.total_batch += 1 - # train on batch data and compute metrics - batch_values = self.train_batch(batch_data, metric_names) - batch_metrics = self.metric_manager["train"].compute_metrics(batch_values, metric_names) - batch_metrics["names"] = batch_data["names"] - batch_metrics["ids"] = batch_data["ids"] - # Merge metrics if Distributed Data Parallel is used - if self.params["training_params"]["use_ddp"]: - batch_metrics = self.merge_ddp_metrics(batch_metrics) - # Update learning rate via scheduler if one is used - if self.params["training_params"]["lr_schedulers"]: - for model_name in self.models: - key = "all" if "all" in self.params["training_params"]["lr_schedulers"] else model_name - if model_name in self.lr_schedulers and ind_batch % self.params["training_params"]["lr_schedulers"][key]["step_interval"] == 0: - self.lr_schedulers[model_name].step(len(batch_metrics["names"])) - if "lr" in metric_names: - self.writer.add_scalar("lr_{}".format(model_name), self.lr_schedulers[model_name].lr, self.lr_schedulers[model_name].step_num) - # Update dropout scheduler if used - if self.dropout_scheduler: - self.dropout_scheduler.step(len(batch_metrics["names"])) - self.dropout_scheduler.update_dropout_rate() - - # Add batch metrics values to epoch metrics values - self.metric_manager["train"].update_metrics(batch_metrics) - display_values = self.metric_manager["train"].get_display_values() - pbar.set_postfix(values=str(display_values)) - pbar.update(len(batch_data["names"])) - - # log metrics in tensorboard file - if self.is_master: - for key in display_values.keys(): - self.writer.add_scalar('{}_{}'.format(self.params["dataset_params"]["train"]["name"], key), display_values[key], num_epoch) - self.latest_train_metrics = display_values - - # evaluate and compute metrics for valid sets - if self.params["training_params"]["eval_on_valid"] and num_epoch % self.params["training_params"]["eval_on_valid_interval"] == 0: - for valid_set_name in self.dataset.valid_loaders.keys(): - # evaluate set and compute metrics - eval_values = self.evaluate(valid_set_name) - self.latest_valid_metrics = eval_values - # log valid metrics in tensorboard file - if self.is_master: - for key in eval_values.keys(): - self.writer.add_scalar('{}_{}'.format(valid_set_name, key), eval_values[key], num_epoch) - if valid_set_name == self.params["training_params"]["set_name_focus_metric"] and (self.best is None or \ - (eval_values[focus_metric_name] <= self.best and self.params["training_params"]["expected_metric_value"] == "low") or\ - (eval_values[focus_metric_name] >= self.best and self.params["training_params"]["expected_metric_value"] == "high")): - self.save_model(epoch=num_epoch, name="best") - self.best = eval_values[focus_metric_name] - - # Handle curriculum learning update - if self.dataset.train_dataset.curriculum_config: - self.check_and_update_curriculum() - - if "curriculum_model" in self.params["model_params"] and self.params["model_params"]["curriculum_model"]: - self.update_curriculum_model() - - # save model weights - if self.is_master: - self.save_model(epoch=num_epoch, name="last") - if interval_save_weights and num_epoch % interval_save_weights == 0: - self.save_model(epoch=num_epoch, name="weigths", keep_weights=True) - self.writer.flush() - - def evaluate(self, set_name, **kwargs): - """ - Main loop for validation - """ - self.phase = "eval" - loader = self.dataset.valid_loaders[set_name] - # Set models in eval mode - for model_name in self.models.keys(): - self.models[model_name].eval() - metric_names = self.params["training_params"]["eval_metrics"] - display_values = None - - # initialize epoch metrics - self.metric_manager[set_name] = MetricManager(metric_names, dataset_name=self.dataset_name) - with tqdm(total=len(loader.dataset)) as pbar: - pbar.set_description("Evaluation E{}".format(self.latest_epoch)) - with torch.no_grad(): - # iterate over batch data - for ind_batch, batch_data in enumerate(loader): - self.latest_batch = ind_batch + 1 - # eval batch data and compute metrics - batch_values = self.evaluate_batch(batch_data, metric_names) - batch_metrics = self.metric_manager[set_name].compute_metrics(batch_values, metric_names) - batch_metrics["names"] = batch_data["names"] - batch_metrics["ids"] = batch_data["ids"] - # merge metrics values if Distributed Data Parallel is used - if self.params["training_params"]["use_ddp"]: - batch_metrics = self.merge_ddp_metrics(batch_metrics) - - # add batch metrics to epoch metrics - self.metric_manager[set_name].update_metrics(batch_metrics) - display_values = self.metric_manager[set_name].get_display_values() - - pbar.set_postfix(values=str(display_values)) - pbar.update(len(batch_data["names"])) - if "cer_by_nb_cols" in metric_names: - self.log_cer_by_nb_cols(set_name) - return display_values - - def predict(self, custom_name, sets_list, metric_names, output=False): - """ - Main loop for evaluation - """ - self.phase = "predict" - metric_names = metric_names.copy() - self.dataset.generate_test_loader(custom_name, sets_list) - loader = self.dataset.test_loaders[custom_name] - # Set models in eval mode - for model_name in self.models.keys(): - self.models[model_name].eval() - - # initialize epoch metrics - self.metric_manager[custom_name] = MetricManager(metric_names, self.dataset_name) - - with tqdm(total=len(loader.dataset)) as pbar: - pbar.set_description("Prediction") - with torch.no_grad(): - for ind_batch, batch_data in enumerate(loader): - # iterates over batch data - self.latest_batch = ind_batch + 1 - # eval batch data and compute metrics - batch_values = self.evaluate_batch(batch_data, metric_names) - batch_metrics = self.metric_manager[custom_name].compute_metrics(batch_values, metric_names) - batch_metrics["names"] = batch_data["names"] - batch_metrics["ids"] = batch_data["ids"] - # merge batch metrics if Distributed Data Parallel is used - if self.params["training_params"]["use_ddp"]: - batch_metrics = self.merge_ddp_metrics(batch_metrics) - - # add batch metrics to epoch metrics - self.metric_manager[custom_name].update_metrics(batch_metrics) - display_values = self.metric_manager[custom_name].get_display_values() - - pbar.set_postfix(values=str(display_values)) - pbar.update(len(batch_data["names"])) - - self.dataset.remove_test_dataset(custom_name) - # output metrics values if requested - if output: - if "pred" in metric_names: - self.output_pred(custom_name) - metrics = self.metric_manager[custom_name].get_display_values(output=True) - path = os.path.join(self.paths["results"], "predict_{}_{}.txt".format(custom_name, self.latest_epoch)) - with open(path, "w") as f: - for metric_name in metrics.keys(): - f.write("{}: {}\n".format(metric_name, metrics[metric_name])) - - def output_pred(self, name): - path = os.path.join(self.paths["results"], "pred_{}_{}.txt".format(name, self.latest_epoch)) - pred = "\n".join(self.metric_manager[name].get("pred")) - with open(path, "w") as f: - f.write(pred) - - def launch_ddp(self): - """ - Initialize Distributed Data Parallel system - """ - mp.set_start_method('fork', force=True) - os.environ['MASTER_ADDR'] = self.ddp_config["address"] - os.environ['MASTER_PORT'] = str(self.ddp_config["port"]) - dist.init_process_group(self.ddp_config["backend"], rank=self.ddp_config["rank"], world_size=self.params["training_params"]["nb_gpu"]) - torch.cuda.set_device(self.ddp_config["rank"]) - random.seed(self.manual_seed) - np.random.seed(self.manual_seed) - torch.manual_seed(self.manual_seed) - torch.cuda.manual_seed(self.manual_seed) - - def merge_ddp_metrics(self, metrics): - """ - Merge metrics when Distributed Data Parallel is used - """ - for metric_name in metrics.keys(): - if metric_name in ["edit_words", "nb_words", "edit_chars", "nb_chars", "edit_chars_force_len", - "edit_chars_curr", "nb_chars_curr", "ids"]: - metrics[metric_name] = self.cat_ddp_metric(metrics[metric_name]) - elif metric_name in ["nb_samples", "loss", "loss_ce", "loss_ctc", "loss_ce_end"]: - metrics[metric_name] = self.sum_ddp_metric(metrics[metric_name], average=False) - return metrics - - def sum_ddp_metric(self, metric, average=False): - """ - Sum metrics for Distributed Data Parallel - """ - sum = torch.tensor(metric[0]).to(self.device) - dist.all_reduce(sum, op=dist.ReduceOp.SUM) - if average: - sum.true_divide(dist.get_world_size()) - return [sum.item(), ] - - def cat_ddp_metric(self, metric): - """ - Concatenate metrics for Distributed Data Parallel - """ - tensor = torch.tensor(metric).unsqueeze(0).to(self.device) - res = [torch.zeros(tensor.size()).long().to(self.device) for _ in range(dist.get_world_size())] - dist.all_gather(res, tensor) - return list(torch.cat(res, dim=0).flatten().cpu().numpy()) - - @staticmethod - def cleanup(): - dist.destroy_process_group() - - def train_batch(self, batch_data, metric_names): - raise NotImplementedError - - def evaluate_batch(self, batch_data, metric_names): - raise NotImplementedError - - def init_curriculum(self): - raise NotImplementedError - - def update_curriculum(self): - raise NotImplementedError - - def add_checkpoint_info(self, load_mode="last", **kwargs): - for filename in os.listdir(self.paths["checkpoints"]): - if load_mode in filename: - checkpoint_path = os.path.join(self.paths["checkpoints"], filename) - checkpoint = torch.load(checkpoint_path) - for key in kwargs.keys(): - checkpoint[key] = kwargs[key] - torch.save(checkpoint, checkpoint_path) - return - self.save_model(self.latest_epoch, "last") - - def load_save_info(self, info_dict): - """ - Load curriculum info from saved model info - """ - if "curriculum_config" in info_dict.keys(): - self.dataset.train_dataset.curriculum_config = info_dict["curriculum_config"] - - def add_save_info(self, info_dict): - """ - Add curriculum info to model info to be saved - """ - info_dict["curriculum_config"] = self.dataset.train_dataset.curriculum_config - return info_dict \ No newline at end of file diff --git a/basic/metric_manager.py b/basic/metric_manager.py deleted file mode 100644 index 6c8576863a2d26a85c2d5dbc4321323dedf870a0..0000000000000000000000000000000000000000 --- a/basic/metric_manager.py +++ /dev/null @@ -1,538 +0,0 @@ - -from Datasets.dataset_formatters.rimes_formatter import SEM_MATCHING_TOKENS as RIMES_MATCHING_TOKENS -from Datasets.dataset_formatters.read2016_formatter import SEM_MATCHING_TOKENS as READ_MATCHING_TOKENS -from Datasets.dataset_formatters.simara_formatter import SEM_MATCHING_TOKENS as SIMARA_MATCHING_TOKENS -import re -import networkx as nx -import editdistance -import numpy as np -from dan.post_processing import PostProcessingModuleREAD, PostProcessingModuleRIMES, PostProcessingModuleSIMARA - - -class MetricManager: - - def __init__(self, metric_names, dataset_name): - self.dataset_name = dataset_name - if "READ" in dataset_name and "page" in dataset_name: - self.post_processing_module = PostProcessingModuleREAD - self.matching_tokens = READ_MATCHING_TOKENS - self.edit_and_num_edge_nodes = edit_and_num_items_for_ged_from_str_read - elif "RIMES" in dataset_name and "page" in dataset_name: - self.post_processing_module = PostProcessingModuleRIMES - self.matching_tokens = RIMES_MATCHING_TOKENS - self.edit_and_num_edge_nodes = edit_and_num_items_for_ged_from_str_rimes - elif "simara" in dataset_name and "page" in dataset_name: - self.post_processing_module = PostProcessingModuleSIMARA - self.matching_tokens = SIMARA_MATCHING_TOKENS - self.edit_and_num_edge_nodes = edit_and_num_items_for_ged_from_str_simara - else: - self.matching_tokens = dict() - - self.layout_tokens = "".join(list(self.matching_tokens.keys()) + list(self.matching_tokens.values())) - if len(self.layout_tokens) == 0: - self.layout_tokens = None - self.metric_names = metric_names - self.epoch_metrics = None - - self.linked_metrics = { - "cer": ["edit_chars", "nb_chars"], - "wer": ["edit_words", "nb_words"], - "loer": ["edit_graph", "nb_nodes_and_edges", "nb_pp_op_layout", "nb_gt_layout_token"], - "precision": ["precision", "weights"], - "map_cer_per_class": ["map_cer", ], - "layout_precision_per_class_per_threshold": ["map_cer", ], - } - - self.init_metrics() - - def init_metrics(self): - """ - Initialization of the metrics specified in metrics_name - """ - self.epoch_metrics = { - "nb_samples": list(), - "names": list(), - "ids": list(), - } - - for metric_name in self.metric_names: - if metric_name in self.linked_metrics: - for linked_metric_name in self.linked_metrics[metric_name]: - if linked_metric_name not in self.epoch_metrics.keys(): - self.epoch_metrics[linked_metric_name] = list() - else: - self.epoch_metrics[metric_name] = list() - - def update_metrics(self, batch_metrics): - """ - Add batch metrics to the metrics - """ - for key in batch_metrics.keys(): - if key in self.epoch_metrics: - self.epoch_metrics[key] += batch_metrics[key] - - def get_display_values(self, output=False): - """ - format metrics values for shell display purposes - """ - metric_names = self.metric_names.copy() - if output: - metric_names.extend(["nb_samples"]) - display_values = dict() - for metric_name in metric_names: - value = None - if output: - if metric_name in ["nb_samples", "weights"]: - value = np.sum(self.epoch_metrics[metric_name]) - elif metric_name in ["time", ]: - total_time = np.sum(self.epoch_metrics[metric_name]) - sample_time = total_time / np.sum(self.epoch_metrics["nb_samples"]) - display_values["sample_time"] = round(sample_time, 4) - value = total_time - elif metric_name == "loer": - display_values["pper"] = round(np.sum(self.epoch_metrics["nb_pp_op_layout"]) / np.sum(self.epoch_metrics["nb_gt_layout_token"]), 4) - elif metric_name == "map_cer_per_class": - value = compute_global_mAP_per_class(self.epoch_metrics["map_cer"]) - for key in value.keys(): - display_values["map_cer_" + key] = round(value[key], 4) - continue - elif metric_name == "layout_precision_per_class_per_threshold": - value = compute_global_precision_per_class_per_threshold(self.epoch_metrics["map_cer"]) - for key_class in value.keys(): - for threshold in value[key_class].keys(): - display_values["map_cer_{}_{}".format(key_class, threshold)] = round( - value[key_class][threshold], 4) - continue - if metric_name == "cer": - value = np.sum(self.epoch_metrics["edit_chars"]) / np.sum(self.epoch_metrics["nb_chars"]) - if output: - display_values["nb_chars"] = np.sum(self.epoch_metrics["nb_chars"]) - elif metric_name == "wer": - value = np.sum(self.epoch_metrics["edit_words"]) / np.sum(self.epoch_metrics["nb_words"]) - if output: - display_values["nb_words"] = np.sum(self.epoch_metrics["nb_words"]) - elif metric_name in ["loss", "loss_ctc", "loss_ce", "syn_max_lines"]: - value = np.average(self.epoch_metrics[metric_name], weights=np.array(self.epoch_metrics["nb_samples"])) - elif metric_name == "map_cer": - value = compute_global_mAP(self.epoch_metrics[metric_name]) - elif metric_name == "loer": - value = np.sum(self.epoch_metrics["edit_graph"]) / np.sum(self.epoch_metrics["nb_nodes_and_edges"]) - elif value is None: - continue - - display_values[metric_name] = round(value, 4) - return display_values - - def compute_metrics(self, values, metric_names): - metrics = { - "nb_samples": [values["nb_samples"], ], - } - for v in ["weights", "time"]: - if v in values: - metrics[v] = [values[v]] - for metric_name in metric_names: - if metric_name == "cer": - metrics["edit_chars"] = [edit_cer_from_string(u, v, self.layout_tokens) for u, v in zip(values["str_y"], values["str_x"])] - metrics["nb_chars"] = [nb_chars_cer_from_string(gt, self.layout_tokens) for gt in values["str_y"]] - elif metric_name == "wer": - split_gt = [format_string_for_wer(gt, self.layout_tokens) for gt in values["str_y"]] - split_pred = [format_string_for_wer(pred, self.layout_tokens) for pred in values["str_x"]] - metrics["edit_words"] = [edit_wer_from_formatted_split_text(gt, pred) for (gt, pred) in zip(split_gt, split_pred)] - metrics["nb_words"] = [len(gt) for gt in split_gt] - elif metric_name in ["loss_ctc", "loss_ce", "loss", "syn_max_lines", ]: - metrics[metric_name] = [values[metric_name], ] - elif metric_name == "map_cer": - pp_pred = list() - pp_score = list() - for pred, score in zip(values["str_x"], values["confidence_score"]): - pred_score = self.post_processing_module().post_process(pred, score) - pp_pred.append(pred_score[0]) - pp_score.append(pred_score[1]) - metrics[metric_name] = [compute_layout_mAP_per_class(y, x, conf, self.matching_tokens) for x, conf, y in zip(pp_pred, pp_score, values["str_y"])] - elif metric_name == "loer": - pp_pred = list() - metrics["nb_pp_op_layout"] = list() - for pred in values["str_x"]: - pp_module = self.post_processing_module() - pp_pred.append(pp_module.post_process(pred)) - metrics["nb_pp_op_layout"].append(pp_module.num_op) - metrics["nb_gt_layout_token"] = [len(keep_only_tokens(str_x, self.layout_tokens)) for str_x in values["str_x"]] - edit_and_num_items = [self.edit_and_num_edge_nodes(y, x) for x, y in zip(pp_pred, values["str_y"])] - metrics["edit_graph"], metrics["nb_nodes_and_edges"] = [ei[0] for ei in edit_and_num_items], [ei[1] for ei in edit_and_num_items] - return metrics - - def get(self, name): - return self.epoch_metrics[name] - - -def keep_only_tokens(str, tokens): - """ - Remove all but layout tokens from string - """ - return re.sub('([^' + tokens + '])', '', str) - - -def keep_all_but_tokens(str, tokens): - """ - Remove all layout tokens from string - """ - return re.sub('([' + tokens + '])', '', str) - - -def edit_cer_from_string(gt, pred, layout_tokens=None): - """ - Format and compute edit distance between two strings at character level - """ - gt = format_string_for_cer(gt, layout_tokens) - pred = format_string_for_cer(pred, layout_tokens) - return editdistance.eval(gt, pred) - - -def nb_chars_cer_from_string(gt, layout_tokens=None): - """ - Compute length after formatting of ground truth string - """ - return len(format_string_for_cer(gt, layout_tokens)) - - -def edit_wer_from_string(gt, pred, layout_tokens=None): - """ - Format and compute edit distance between two strings at word level - """ - split_gt = format_string_for_wer(gt, layout_tokens) - split_pred = format_string_for_wer(pred, layout_tokens) - return edit_wer_from_formatted_split_text(split_gt, split_pred) - - -def format_string_for_wer(str, layout_tokens): - """ - Format string for WER computation: remove layout tokens, treat punctuation as word, replace line break by space - """ - str = re.sub('([\[\]{}/\\()\"\'&+*=<>?.;:,!\-—_€#%°])', r' \1 ', str) # punctuation processed as word - if layout_tokens is not None: - str = keep_all_but_tokens(str, layout_tokens) # remove layout tokens from metric - str = re.sub('([ \n])+', " ", str).strip() # keep only one space character - return str.split(" ") - - -def format_string_for_cer(str, layout_tokens): - """ - Format string for CER computation: remove layout tokens and extra spaces - """ - if layout_tokens is not None: - str = keep_all_but_tokens(str, layout_tokens) # remove layout tokens from metric - str = re.sub('([\n])+', "\n", str) # remove consecutive line breaks - str = re.sub('([ ])+', " ", str).strip() # remove consecutive spaces - return str - - -def edit_wer_from_formatted_split_text(gt, pred): - """ - Compute edit distance at word level from formatted string as list - """ - return editdistance.eval(gt, pred) - - -def extract_by_tokens(input_str, begin_token, end_token, associated_score=None, order_by_score=False): - """ - Extract list of text regions by begin and end tokens - Order the list by confidence score - """ - if order_by_score: - assert associated_score is not None - res = list() - for match in re.finditer("{}[^{}]*{}".format(begin_token, end_token, end_token), input_str): - begin, end = match.regs[0] - if order_by_score: - res.append({ - "confidence": np.mean([associated_score[begin], associated_score[end-1]]), - "content": input_str[begin+1:end-1] - }) - else: - res.append(input_str[begin+1:end-1]) - if order_by_score: - res = sorted(res, key=lambda x: x["confidence"], reverse=True) - res = [r["content"] for r in res] - return res - - -def compute_layout_precision_per_threshold(gt, pred, score, begin_token, end_token, layout_tokens, return_weight=True): - """ - Compute average precision of a given class for CER threshold from 5% to 50% with a step of 5% - """ - pred_list = extract_by_tokens(pred, begin_token, end_token, associated_score=score, order_by_score=True) - gt_list = extract_by_tokens(gt, begin_token, end_token) - pred_list = [keep_all_but_tokens(p, layout_tokens) for p in pred_list] - gt_list = [keep_all_but_tokens(gt, layout_tokens) for gt in gt_list] - precision_per_threshold = [compute_layout_AP_for_given_threshold(gt_list, pred_list, threshold/100) for threshold in range(5, 51, 5)] - if return_weight: - return precision_per_threshold, len(gt_list) - return precision_per_threshold - - -def compute_layout_AP_for_given_threshold(gt_list, pred_list, threshold): - """ - Compute average precision of a given class for a given CER threshold - """ - remaining_gt_list = gt_list.copy() - num_true = len(gt_list) - correct = np.zeros((len(pred_list)), dtype=np.bool) - for i, pred in enumerate(pred_list): - if len(remaining_gt_list) == 0: - break - cer_with_gt = [edit_cer_from_string(gt, pred)/nb_chars_cer_from_string(gt) for gt in remaining_gt_list] - cer, ind = np.min(cer_with_gt), np.argmin(cer_with_gt) - if cer <= threshold: - correct[i] = True - del remaining_gt_list[ind] - precision = np.cumsum(correct, dtype=np.int) / np.arange(1, len(pred_list)+1) - recall = np.cumsum(correct, dtype=np.int) / num_true - max_precision_from_recall = np.maximum.accumulate(precision[::-1])[::-1] - recall_diff = (recall - np.concatenate([np.array([0, ]), recall[:-1]])) - P = np.sum(recall_diff * max_precision_from_recall) - return P - - -def compute_layout_mAP_per_class(gt, pred, score, tokens): - """ - Compute the mAP_cer for each class for a given sample - """ - layout_tokens = "".join(list(tokens.keys())) - AP_per_class = dict() - for token in tokens.keys(): - if token in gt: - AP_per_class[token] = compute_layout_precision_per_threshold(gt, pred, score, token, tokens[token], layout_tokens=layout_tokens) - return AP_per_class - - -def compute_global_mAP(list_AP_per_class): - """ - Compute the global mAP_cer for several samples - """ - weights_per_doc = list() - mAP_per_doc = list() - for doc_AP_per_class in list_AP_per_class: - APs = np.array([np.mean(doc_AP_per_class[key][0]) for key in doc_AP_per_class.keys()]) - weights = np.array([doc_AP_per_class[key][1] for key in doc_AP_per_class.keys()]) - if np.sum(weights) == 0: - mAP_per_doc.append(0) - else: - mAP_per_doc.append(np.average(APs, weights=weights)) - weights_per_doc.append(np.sum(weights)) - if np.sum(weights_per_doc) == 0: - return 0 - return np.average(mAP_per_doc, weights=weights_per_doc) - - -def compute_global_mAP_per_class(list_AP_per_class): - """ - Compute the mAP_cer per class for several samples - """ - mAP_per_class = dict() - for doc_AP_per_class in list_AP_per_class: - for key in doc_AP_per_class.keys(): - if key not in mAP_per_class: - mAP_per_class[key] = { - "AP": list(), - "weights": list() - } - mAP_per_class[key]["AP"].append(np.mean(doc_AP_per_class[key][0])) - mAP_per_class[key]["weights"].append(doc_AP_per_class[key][1]) - for key in mAP_per_class.keys(): - mAP_per_class[key] = np.average(mAP_per_class[key]["AP"], weights=mAP_per_class[key]["weights"]) - return mAP_per_class - - -def compute_global_precision_per_class_per_threshold(list_AP_per_class): - """ - Compute the mAP_cer per class and per threshold for several samples - """ - mAP_per_class = dict() - for doc_AP_per_class in list_AP_per_class: - for key in doc_AP_per_class.keys(): - if key not in mAP_per_class: - mAP_per_class[key] = dict() - for threshold in range(5, 51, 5): - mAP_per_class[key][threshold] = { - "precision": list(), - "weights": list() - } - for i, threshold in enumerate(range(5, 51, 5)): - mAP_per_class[key][threshold]["precision"].append(np.mean(doc_AP_per_class[key][0][i])) - mAP_per_class[key][threshold]["weights"].append(doc_AP_per_class[key][1]) - for key_class in mAP_per_class.keys(): - for threshold in mAP_per_class[key_class]: - mAP_per_class[key_class][threshold] = np.average(mAP_per_class[key_class][threshold]["precision"], weights=mAP_per_class[key_class][threshold]["weights"]) - return mAP_per_class - - -def str_to_graph_read(str): - """ - Compute graph from string of layout tokens for the READ 2016 dataset at single-page and double-page levels - """ - begin_layout_tokens = "".join(list(READ_MATCHING_TOKENS.keys())) - layout_token_sequence = keep_only_tokens(str, begin_layout_tokens) - g = nx.DiGraph() - g.add_node("D", type="document", level=4, page=0) - num = { - "â“Ÿ": 0, - "â“": 0, - "â“‘": 0, - "â“": 0, - "â“¢": 0 - } - previous_top_level_node = None - previous_middle_level_node = None - previous_low_level_node = None - for ind, c in enumerate(layout_token_sequence): - num[c] += 1 - if c == "â“Ÿ": - node_name = "P_{}".format(num[c]) - g.add_node(node_name, type="page", level=3, page=num["â“Ÿ"]) - g.add_edge("D", node_name) - if previous_top_level_node: - g.add_edge(previous_top_level_node, node_name) - previous_top_level_node = node_name - previous_middle_level_node = None - previous_low_level_node = None - if c in "â“â“¢": - node_name = "{}_{}".format("N" if c == "â“" else "S", num[c]) - g.add_node(node_name, type="number" if c == "â“" else "section", level=2, page=num["â“Ÿ"]) - g.add_edge(previous_top_level_node, node_name) - if previous_middle_level_node: - g.add_edge(previous_middle_level_node, node_name) - previous_middle_level_node = node_name - previous_low_level_node = None - if c in "â“â“‘": - node_name = "{}_{}".format("A" if c == "â“" else "B", num[c]) - g.add_node(node_name, type="annotation" if c == "â“" else "body", level=1, page=num["â“Ÿ"]) - g.add_edge(previous_middle_level_node, node_name) - if previous_low_level_node: - g.add_edge(previous_low_level_node, node_name) - previous_low_level_node = node_name - return g - - -def str_to_graph_rimes(str): - """ - Compute graph from string of layout tokens for the RIMES dataset at page level - """ - begin_layout_tokens = "".join(list(RIMES_MATCHING_TOKENS.keys())) - layout_token_sequence = keep_only_tokens(str, begin_layout_tokens) - g = nx.DiGraph() - g.add_node("D", type="document", level=2, page=0) - token_name_dict = { - "â“‘": "B", - "â“ž": "O", - "â“¡": "R", - "â“¢": "S", - "ⓦ": "W", - "ⓨ": "Y", - "â“Ÿ": "P" - } - num = dict() - previous_node = None - for token in begin_layout_tokens: - num[token] = 0 - for ind, c in enumerate(layout_token_sequence): - num[c] += 1 - node_name = "{}_{}".format(token_name_dict[c], num[c]) - g.add_node(node_name, type=token_name_dict[c], level=1, page=0) - g.add_edge("D", node_name) - if previous_node: - g.add_edge(previous_node, node_name) - previous_node = node_name - return g - - -def str_to_graph_simara(str): - """ - Compute graph from string of layout tokens for the SIMARA dataset at page level - """ - begin_layout_tokens = "".join(list(SIMARA_MATCHING_TOKENS.keys())) - layout_token_sequence = keep_only_tokens(str, begin_layout_tokens) - g = nx.DiGraph() - g.add_node("D", type="document", level=2, page=0) - token_name_dict = { - "ⓘ": "I", - "â““": "D", - "â“¢": "S", - "â“’": "C", - "â“Ÿ": "P", - "â“": "A" - } - num = dict() - previous_node = None - for token in begin_layout_tokens: - num[token] = 0 - for ind, c in enumerate(layout_token_sequence): - num[c] += 1 - node_name = "{}_{}".format(token_name_dict[c], num[c]) - g.add_node(node_name, type=token_name_dict[c], level=1, page=0) - g.add_edge("D", node_name) - if previous_node: - g.add_edge(previous_node, node_name) - previous_node = node_name - return g - - -def graph_edit_distance_by_page_read(g1, g2): - """ - Compute graph edit distance page by page for the READ 2016 dataset - """ - num_pages_g1 = len([n for n in g1.nodes().items() if n[1]["level"] == 3]) - num_pages_g2 = len([n for n in g2.nodes().items() if n[1]["level"] == 3]) - page_graphs_1 = [g1.subgraph([n[0] for n in g1.nodes().items() if n[1]["page"] == num_page]) for num_page in range(1, num_pages_g1+1)] - page_graphs_2 = [g2.subgraph([n[0] for n in g2.nodes().items() if n[1]["page"] == num_page]) for num_page in range(1, num_pages_g2+1)] - edit = 0 - for i in range(max(len(page_graphs_1), len(page_graphs_2))): - page_1 = page_graphs_1[i] if i < len(page_graphs_1) else nx.DiGraph() - page_2 = page_graphs_2[i] if i < len(page_graphs_2) else nx.DiGraph() - edit += graph_edit_distance(page_1, page_2) - return edit - - -def graph_edit_distance(g1, g2): - """ - Compute graph edit distance between two graphs - """ - for v in nx.optimize_graph_edit_distance(g1, g2, - node_ins_cost=lambda node: 1, - node_del_cost=lambda node: 1, - node_subst_cost=lambda node1, node2: 0 if node1["type"] == node2["type"] else 1, - edge_ins_cost=lambda edge: 1, - edge_del_cost=lambda edge: 1, - edge_subst_cost=lambda edge1, edge2: 0 if edge1 == edge2 else 1 - ): - new_edit = v - return new_edit - - -def edit_and_num_items_for_ged_from_str_read(str_gt, str_pred): - """ - Compute graph edit distance and num nodes/edges for normalized graph edit distance - For the READ 2016 dataset - """ - g_gt = str_to_graph_read(str_gt) - g_pred = str_to_graph_read(str_pred) - return graph_edit_distance_by_page_read(g_gt, g_pred), g_gt.number_of_nodes() + g_gt.number_of_edges() - - -def edit_and_num_items_for_ged_from_str_rimes(str_gt, str_pred): - """ - Compute graph edit distance and num nodes/edges for normalized graph edit distance - For the RIMES dataset - """ - g_gt = str_to_graph_rimes(str_gt) - g_pred = str_to_graph_rimes(str_pred) - return graph_edit_distance(g_gt, g_pred), g_gt.number_of_nodes() + g_gt.number_of_edges() - - -def edit_and_num_items_for_ged_from_str_simara(str_gt, str_pred): - """ - Compute graph edit distance and num nodes/edges for normalized graph edit distance - For the SIMARA dataset - """ - g_gt = str_to_graph_simara(str_gt) - g_pred = str_to_graph_simara(str_pred) - return graph_edit_distance(g_gt, g_pred), g_gt.number_of_nodes() + g_gt.number_of_edges() diff --git a/basic/scheduler.py b/basic/scheduler.py deleted file mode 100644 index 6c875c1da2f57ab0306635cbc77309c2b83af116..0000000000000000000000000000000000000000 --- a/basic/scheduler.py +++ /dev/null @@ -1,51 +0,0 @@ - -from torch.nn import Dropout, Dropout2d -import numpy as np - - -class DropoutScheduler: - - def __init__(self, models, function, T=1e5): - """ - T: number of gradient updates to converge - """ - - self.teta_list = list() - self.init_teta_list(models) - self.function = function - self.T = T - self.step_num = 0 - - def step(self): - self.step(1) - - def step(self, num): - self.step_num += num - - def init_teta_list(self, models): - for model_name in models.keys(): - self.init_teta_list_module(models[model_name]) - - def init_teta_list_module(self, module): - for child in module.children(): - if isinstance(child, Dropout) or isinstance(child, Dropout2d): - self.teta_list.append([child, child.p]) - else: - self.init_teta_list_module(child) - - def update_dropout_rate(self): - for (module, p) in self.teta_list: - module.p = self.function(p, self.step_num, self.T) - - -def exponential_dropout_scheduler(dropout_rate, step, max_step): - return dropout_rate * (1 - np.exp(-10 * step / max_step)) - - -def exponential_scheduler(init_value, end_value, step, max_step): - step = min(step, max_step-1) - return init_value - (init_value - end_value) * (1 - np.exp(-10*step/max_step)) - - -def linear_scheduler(init_value, end_value, step, max_step): - return init_value + step * (end_value - init_value) / max_step \ No newline at end of file diff --git a/basic/transforms.py b/basic/transforms.py deleted file mode 100644 index 18c8084dafe668025ea7ae58ce99384b58a672a1..0000000000000000000000000000000000000000 --- a/basic/transforms.py +++ /dev/null @@ -1,438 +0,0 @@ - -import numpy as np -from numpy import random -from PIL import Image, ImageOps -from cv2 import erode, dilate, normalize -import cv2 -import math -from basic.utils import randint, rand_uniform, rand -from torchvision.transforms import RandomPerspective, RandomCrop, ColorJitter, GaussianBlur, RandomRotation -from torchvision.transforms.functional import InterpolationMode - -""" -Each transform class defined here takes as input a PIL Image and returns the modified PIL Image -""" - - -class SignFlipping: - """ - Color inversion - """ - - def __init__(self): - pass - - def __call__(self, x): - return ImageOps.invert(x) - - -class DPIAdjusting: - """ - Resolution modification - """ - - def __init__(self, factor, preserve_ratio): - self.factor = factor - - def __call__(self, x): - w, h = x.size - return x.resize((int(np.ceil(w * self.factor)), int(np.ceil(h * self.factor))), Image.BILINEAR) - - -class Dilation: - """ - OCR: stroke width increasing - """ - - def __init__(self, kernel, iterations): - self.kernel = np.ones(kernel, np.uint8) - self.iterations = iterations - - def __call__(self, x): - return Image.fromarray(dilate(np.array(x), self.kernel, iterations=self.iterations)) - - -class Erosion: - """ - OCR: stroke width decreasing - """ - - def __init__(self, kernel, iterations): - self.kernel = np.ones(kernel, np.uint8) - self.iterations = iterations - - def __call__(self, x): - return Image.fromarray(erode(np.array(x), self.kernel, iterations=self.iterations)) - - -class GaussianNoise: - """ - Add Gaussian Noise - """ - - def __init__(self, std): - self.std = std - - def __call__(self, x): - x_np = np.array(x) - mean, std = np.mean(x_np), np.std(x_np) - std = math.copysign(max(abs(std), 0.000001), std) - min_, max_ = np.min(x_np,), np.max(x_np) - normal_noise = np.random.randn(*x_np.shape) - if len(x_np.shape) == 3 and x_np.shape[2] == 3 and np.all(x_np[:, :, 0] == x_np[:, :, 1]) and np.all(x_np[:, :, 0] == x_np[:, :, 2]): - normal_noise[:, :, 1] = normal_noise[:, :, 2] = normal_noise[:, :, 0] - x_np = ((x_np-mean)/std + normal_noise*self.std) * std + mean - x_np = normalize(x_np, x_np, max_, min_, cv2.NORM_MINMAX) - - return Image.fromarray(x_np.astype(np.uint8)) - - -class Sharpen: - """ - Add Gaussian Noise - """ - - def __init__(self, alpha, strength): - self.alpha = alpha - self.strength = strength - - def __call__(self, x): - x_np = np.array(x) - id_matrix = np.array([[0, 0, 0], - [0, 1, 0], - [0, 0, 0]] - ) - effect_matrix = np.array([[1, 1, 1], - [1, -(8+self.strength), 1], - [1, 1, 1]] - ) - kernel = (1 - self.alpha) * id_matrix - self.alpha * effect_matrix - kernel = np.expand_dims(kernel, axis=2) - kernel = np.concatenate([kernel, kernel, kernel], axis=2) - sharpened = cv2.filter2D(x_np, -1, kernel=kernel[:, :, 0]) - return Image.fromarray(sharpened.astype(np.uint8)) - - -class ZoomRatio: - """ - Crop by ratio - Preserve dimensions if keep_dim = True (= zoom) - """ - - def __init__(self, ratio_h, ratio_w, keep_dim=True): - self.ratio_w = ratio_w - self.ratio_h = ratio_h - self.keep_dim = keep_dim - - def __call__(self, x): - w, h = x.size - x = RandomCrop((int(h * self.ratio_h), int(w * self.ratio_w)))(x) - if self.keep_dim: - x = x.resize((w, h), Image.BILINEAR) - return x - - -class ElasticDistortion: - - def __init__(self, kernel_size=(7, 7), sigma=5, alpha=1): - - self.kernel_size = kernel_size - self.sigma = sigma - self.alpha = alpha - - def __call__(self, x): - x_np = np.array(x) - - h, w = x_np.shape[:2] - - dx = np.random.uniform(-1, 1, (h, w)) - dy = np.random.uniform(-1, 1, (h, w)) - - x_gauss = cv2.GaussianBlur(dx, self.kernel_size, self.sigma) - y_gauss = cv2.GaussianBlur(dy, self.kernel_size, self.sigma) - - n = np.sqrt(x_gauss**2 + y_gauss**2) - - nd_x = self.alpha * x_gauss / n - nd_y = self.alpha * y_gauss / n - - ind_y, ind_x = np.indices((h, w), dtype=np.float32) - - map_x = nd_x + ind_x - map_x = map_x.reshape(h, w).astype(np.float32) - map_y = nd_y + ind_y - map_y = map_y.reshape(h, w).astype(np.float32) - - dst = cv2.remap(x_np, map_x, map_y, cv2.INTER_LINEAR) - return Image.fromarray(dst.astype(np.uint8)) - - -class Tightening: - """ - Reduce interline spacing - """ - - def __init__(self, color=255, remove_proba=0.75): - self.color = color - self.remove_proba = remove_proba - - def __call__(self, x): - x_np = np.array(x) - interline_indices = [np.all(line == 255) for line in x_np] - indices_to_removed = np.logical_and(np.random.choice([True, False], size=len(x_np), replace=True, p=[self.remove_proba, 1-self.remove_proba]), interline_indices) - new_x = x_np[np.logical_not(indices_to_removed)] - return Image.fromarray(new_x.astype(np.uint8)) - - -def get_list_augmenters(img, aug_configs, fill_value): - """ - Randomly select a list of data augmentation techniques to used based on aug_configs - """ - augmenters = list() - for aug_config in aug_configs: - if rand() > aug_config["proba"]: - continue - if aug_config["type"] == "dpi": - valid_factor = False - while not valid_factor: - factor = rand_uniform(aug_config["min_factor"], aug_config["max_factor"]) - valid_factor = not (("max_width" in aug_config and factor*img.size[0] > aug_config["max_width"]) or \ - ("max_height" in aug_config and factor * img.size[1] > aug_config["max_height"]) or \ - ("min_width" in aug_config and factor*img.size[0] < aug_config["min_width"]) or \ - ("min_height" in aug_config and factor * img.size[1] < aug_config["min_height"])) - augmenters.append(DPIAdjusting(factor, preserve_ratio=aug_config["preserve_ratio"])) - - elif aug_config["type"] == "zoom_ratio": - ratio_h = rand_uniform(aug_config["min_ratio_h"], aug_config["max_ratio_h"]) - ratio_w = rand_uniform(aug_config["min_ratio_w"], aug_config["max_ratio_w"]) - augmenters.append(ZoomRatio(ratio_h=ratio_h, ratio_w=ratio_w, keep_dim=aug_config["keep_dim"])) - - elif aug_config["type"] == "perspective": - scale = rand_uniform(aug_config["min_factor"], aug_config["max_factor"]) - augmenters.append(RandomPerspective(distortion_scale=scale, p=1, interpolation=InterpolationMode.BILINEAR, fill=fill_value)) - - elif aug_config["type"] == "elastic_distortion": - kernel_size = randint(aug_config["min_kernel_size"], aug_config["max_kernel_size"]) // 2 * 2 + 1 - sigma = rand_uniform(aug_config["min_sigma"], aug_config["max_sigma"]) - alpha= rand_uniform(aug_config["min_alpha"], aug_config["max_alpha"]) - augmenters.append(ElasticDistortion(kernel_size=(kernel_size, kernel_size), sigma=sigma, alpha=alpha)) - - elif aug_config["type"] == "dilation_erosion": - kernel_h = randint(aug_config["min_kernel"], aug_config["max_kernel"] + 1) - kernel_w = randint(aug_config["min_kernel"], aug_config["max_kernel"] + 1) - if randint(0, 2) == 0: - augmenters.append(Erosion((kernel_w, kernel_h), aug_config["iterations"])) - else: - augmenters.append(Dilation((kernel_w, kernel_h), aug_config["iterations"])) - - elif aug_config["type"] == "color_jittering": - augmenters.append(ColorJitter(contrast=aug_config["factor_contrast"], - brightness=aug_config["factor_brightness"], - saturation=aug_config["factor_saturation"], - hue=aug_config["factor_hue"], - )) - - elif aug_config["type"] == "gaussian_blur": - max_kernel_h = min(aug_config["max_kernel"], img.size[1]) - max_kernel_w = min(aug_config["max_kernel"], img.size[0]) - kernel_h = randint(aug_config["min_kernel"], max_kernel_h + 1) // 2 * 2 + 1 - kernel_w = randint(aug_config["min_kernel"], max_kernel_w + 1) // 2 * 2 + 1 - sigma = rand_uniform(aug_config["min_sigma"], aug_config["max_sigma"]) - augmenters.append(GaussianBlur(kernel_size=(kernel_w, kernel_h), sigma=sigma)) - - elif aug_config["type"] == "gaussian_noise": - augmenters.append(GaussianNoise(std=aug_config["std"])) - - elif aug_config["type"] == "sharpen": - alpha = rand_uniform(aug_config["min_alpha"], aug_config["max_alpha"]) - strength = rand_uniform(aug_config["min_strength"], aug_config["max_strength"]) - augmenters.append(Sharpen(alpha=alpha, strength=strength)) - - else: - print("Error - unknown augmentor: {}".format(aug_config["type"])) - exit(-1) - - return augmenters - - -def apply_data_augmentation(img, da_config): - """ - Apply data augmentation strategy on input image - """ - applied_da = list() - if da_config["proba"] != 1 and rand() > da_config["proba"]: - return img, applied_da - - # Convert to PIL Image - img = img[:, :, 0] if img.shape[2] == 1 else img - img = Image.fromarray(img) - - fill_value = da_config["fill_value"] if "fill_value" in da_config else 255 - augmenters = get_list_augmenters(img, da_config["augmentations"], fill_value=fill_value) - if da_config["order"] == "random": - random.shuffle(augmenters) - - for augmenter in augmenters: - img = augmenter(img) - applied_da.append(type(augmenter).__name__) - - # convert to numpy array - img = np.array(img) - img = np.expand_dims(img, axis=2) if len(img.shape) == 2 else img - return img, applied_da - - -def apply_transform(img, transform): - """ - Apply data augmentation technique on input image - """ - img = img[:, :, 0] if img.shape[2] == 1 else img - img = Image.fromarray(img) - img = transform(img) - img = np.array(img) - return np.expand_dims(img, axis=2) if len(img.shape) == 2 else img - - -def line_aug_config(proba_use_da, p): - return { - "order": "random", - "proba": proba_use_da, - "augmentations": [ - { - "type": "dpi", - "proba": p, - "min_factor": 0.5, - "max_factor": 1.5, - "preserve_ratio": True, - }, - { - "type": "perspective", - "proba": p, - "min_factor": 0, - "max_factor": 0.4, - }, - { - "type": "elastic_distortion", - "proba": p, - "min_alpha": 0.5, - "max_alpha": 1, - "min_sigma": 1, - "max_sigma": 10, - "min_kernel_size": 3, - "max_kernel_size": 9, - }, - { - "type": "dilation_erosion", - "proba": p, - "min_kernel": 1, - "max_kernel": 3, - "iterations": 1, - }, - { - "type": "color_jittering", - "proba": p, - "factor_hue": 0.2, - "factor_brightness": 0.4, - "factor_contrast": 0.4, - "factor_saturation": 0.4, - }, - { - "type": "gaussian_blur", - "proba": p, - "min_kernel": 3, - "max_kernel": 5, - "min_sigma": 3, - "max_sigma": 5, - }, - { - "type": "gaussian_noise", - "proba": p, - "std": 0.5, - }, - { - "type": "sharpen", - "proba": p, - "min_alpha": 0, - "max_alpha": 1, - "min_strength": 0, - "max_strength": 1, - }, - { - "type": "zoom_ratio", - "proba": p, - "min_ratio_h": 0.8, - "max_ratio_h": 1, - "min_ratio_w": 0.99, - "max_ratio_w": 1, - "keep_dim": True - }, - ] - } - - -def aug_config(proba_use_da, p): - return { - "order": "random", - "proba": proba_use_da, - "augmentations": [ - { - "type": "dpi", - "proba": p, - "min_factor": 0.75, - "max_factor": 1, - "preserve_ratio": True, - }, - { - "type": "perspective", - "proba": p, - "min_factor": 0, - "max_factor": 0.4, - }, - { - "type": "elastic_distortion", - "proba": p, - "min_alpha": 0.5, - "max_alpha": 1, - "min_sigma": 1, - "max_sigma": 10, - "min_kernel_size": 3, - "max_kernel_size": 9, - }, - { - "type": "dilation_erosion", - "proba": p, - "min_kernel": 1, - "max_kernel": 3, - "iterations": 1, - }, - { - "type": "color_jittering", - "proba": p, - "factor_hue": 0.2, - "factor_brightness": 0.4, - "factor_contrast": 0.4, - "factor_saturation": 0.4, - }, - { - "type": "gaussian_blur", - "proba": p, - "min_kernel": 3, - "max_kernel": 5, - "min_sigma": 3, - "max_sigma": 5, - }, - { - "type": "gaussian_noise", - "proba": p, - "std": 0.5, - }, - { - "type": "sharpen", - "proba": p, - "min_alpha": 0, - "max_alpha": 1, - "min_strength": 0, - "max_strength": 1, - }, - ] - } diff --git a/basic/utils.py b/basic/utils.py deleted file mode 100644 index ac50e57ab1136a2f2e1d209c39271406316d0479..0000000000000000000000000000000000000000 --- a/basic/utils.py +++ /dev/null @@ -1,179 +0,0 @@ - - -import numpy as np -import torch -from torch.distributions.uniform import Uniform -import cv2 - - -def randint(low, high): - """ - call torch.randint to preserve random among dataloader workers - """ - return int(torch.randint(low, high, (1, ))) - - -def rand(): - """ - call torch.rand to preserve random among dataloader workers - """ - return float(torch.rand((1, ))) - - -def rand_uniform(low, high): - """ - call torch uniform to preserve random among dataloader workers - """ - return float(Uniform(low, high).sample()) - - -def pad_sequences_1D(data, padding_value): - """ - Pad data with padding_value to get same length - """ - x_lengths = [len(x) for x in data] - longest_x = max(x_lengths) - padded_data = np.ones((len(data), longest_x)).astype(np.int32) * padding_value - for i, x_len in enumerate(x_lengths): - padded_data[i, :x_len] = data[i][:x_len] - return padded_data - - -def resize_max(img, max_width=None, max_height=None): - if max_width is not None and img.shape[1] > max_width: - ratio = max_width / img.shape[1] - new_h = int(np.floor(ratio * img.shape[0])) - new_w = int(np.floor(ratio * img.shape[1])) - img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_LINEAR) - if max_height is not None and img.shape[0] > max_height: - ratio = max_height / img.shape[0] - new_h = int(np.floor(ratio * img.shape[0])) - new_w = int(np.floor(ratio * img.shape[1])) - img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_LINEAR) - return img - - -def pad_images(data, padding_value, padding_mode="br"): - """ - data: list of numpy array - mode: "br"/"tl"/"random" (bottom-right, top-left, random) - """ - x_lengths = [x.shape[0] for x in data] - y_lengths = [x.shape[1] for x in data] - longest_x = max(x_lengths) - longest_y = max(y_lengths) - padded_data = np.ones((len(data), longest_x, longest_y, data[0].shape[2])) * padding_value - for i, xy_len in enumerate(zip(x_lengths, y_lengths)): - x_len, y_len = xy_len - if padding_mode == "br": - padded_data[i, :x_len, :y_len, ...] = data[i] - elif padding_mode == "tl": - padded_data[i, -x_len:, -y_len:, ...] = data[i] - elif padding_mode == "random": - xmax = longest_x - x_len - ymax = longest_y - y_len - xi = randint(0, xmax) if xmax >= 1 else 0 - yi = randint(0, ymax) if ymax >= 1 else 0 - padded_data[i, xi:xi+x_len, yi:yi+y_len, ...] = data[i] - else: - raise NotImplementedError("Undefined padding mode: {}".format(padding_mode)) - return padded_data - - -def pad_image(image, padding_value, new_height=None, new_width=None, pad_width=None, pad_height=None, padding_mode="br", return_position=False): - """ - data: list of numpy array - mode: "br"/"tl"/"random" (bottom-right, top-left, random) - """ - if pad_width is not None and new_width is not None: - raise NotImplementedError("pad_with and new_width are not compatible") - if pad_height is not None and new_height is not None: - raise NotImplementedError("pad_height and new_height are not compatible") - - h, w, c = image.shape - pad_width = pad_width if pad_width is not None else max(0, new_width - w) if new_width is not None else 0 - pad_height = pad_height if pad_height is not None else max(0, new_height - h) if new_height is not None else 0 - - if not (pad_width == 0 and pad_height == 0): - padded_image = np.ones((h+pad_height, w+pad_width, c)) * padding_value - if padding_mode == "br": - hi, wi = 0, 0 - elif padding_mode == "tl": - hi, wi = pad_height, pad_width - elif padding_mode == "random": - hi = randint(0, pad_height) if pad_height >= 1 else 0 - wi = randint(0, pad_width) if pad_width >= 1 else 0 - else: - raise NotImplementedError("Undefined padding mode: {}".format(padding_mode)) - padded_image[hi:hi + h, wi:wi + w, ...] = image - output = padded_image - else: - hi, wi = 0, 0 - output = image - - if return_position: - return output, [[hi, hi+h], [wi, wi+w]] - return output - - -def pad_image_width_right(img, new_width, padding_value): - """ - Pad img to right side with padding value to reach new_width as width - """ - h, w, c = img.shape - pad_width = max((new_width - w), 0) - pad_right = np.ones((h, pad_width, c), dtype=img.dtype) * padding_value - img = np.concatenate([img, pad_right], axis=1) - return img - - -def pad_image_width_left(img, new_width, padding_value): - """ - Pad img to left side with padding value to reach new_width as width - """ - h, w, c = img.shape - pad_width = max((new_width - w), 0) - pad_left = np.ones((h, pad_width, c), dtype=img.dtype) * padding_value - img = np.concatenate([pad_left, img], axis=1) - return img - - -def pad_image_width_random(img, new_width, padding_value, max_pad_left_ratio=1): - """ - Randomly pad img to left and right sides with padding value to reach new_width as width - """ - h, w, c = img.shape - pad_width = max((new_width - w), 0) - max_pad_left = int(max_pad_left_ratio*pad_width) - pad_left = randint(0, min(pad_width, max_pad_left)) if pad_width != 0 and max_pad_left > 0 else 0 - pad_right = pad_width - pad_left - pad_left = np.ones((h, pad_left, c), dtype=img.dtype) * padding_value - pad_right = np.ones((h, pad_right, c), dtype=img.dtype) * padding_value - img = np.concatenate([pad_left, img, pad_right], axis=1) - return img - - -def pad_image_height_random(img, new_height, padding_value, max_pad_top_ratio=1): - """ - Randomly pad img top and bottom sides with padding value to reach new_width as width - """ - h, w, c = img.shape - pad_height = max((new_height - h), 0) - max_pad_top = int(max_pad_top_ratio*pad_height) - pad_top = randint(0, min(pad_height, max_pad_top)) if pad_height != 0 and max_pad_top > 0 else 0 - pad_bottom = pad_height - pad_top - pad_top = np.ones((pad_top, w, c), dtype=img.dtype) * padding_value - pad_bottom = np.ones((pad_bottom, w, c), dtype=img.dtype) * padding_value - img = np.concatenate([pad_top, img, pad_bottom], axis=0) - return img - - -def pad_image_height_bottom(img, new_height, padding_value): - """ - Pad img to bottom side with padding value to reach new_height as height - """ - h, w, c = img.shape - pad_height = max((new_height - h), 0) - pad_bottom = np.ones((pad_height, w, c)) * padding_value - img = np.concatenate([img, pad_bottom], axis=0) - return img diff --git a/prediction-requirements.txt b/prediction-requirements.txt index 1fe32d37816bfb6c2456ebe2bb7b9858ffef49e9..86bfaeb8bf1dcef9ad52f4c64507482f5d942024 100644 --- a/prediction-requirements.txt +++ b/prediction-requirements.txt @@ -1,4 +1,8 @@ -numpy==1.22.3 -opencv-python==4.5.5.64 -PyYAML==6.0 -torch==1.11.0 +arkindex-client==1.0.11 +editdistance==0.6.0 +fontTools==4.29.1 +imageio==2.16.0 +networkx==2.6.3 +tensorboard==0.2.1 +torchvision==0.12.0 +tqdm==4.62.3 diff --git a/requirements.txt b/requirements.txt index 06c4835601286c1154ad24641bf4e102babf8fe1..1fe32d37816bfb6c2456ebe2bb7b9858ffef49e9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,81 +1,4 @@ -absl-py==1.0.0 -backports.cached-property==1.0.1 -cachetools==5.0.0 -certifi==2021.10.8 -charset-normalizer==2.0.12 -click==8.0.4 -click-default-group==1.2.2 -cloup==0.7.1 -colorama==0.4.4 -colour==0.1.5 -commonmark==0.9.1 -cycler==0.11.0 -decorator==5.1.1 -EasyProcess==1.1 -editdistance==0.6.0 -entrypoint2==1.0 -fonttools==4.29.1 -glcontext==2.3.4 -google-auth==2.6.0 -google-auth-oauthlib==0.4.6 -grpcio==1.44.0 -idna==3.3 -imageio==2.16.0 -importlib-metadata==4.11.1 -isosurfaces==0.1.0 -joblib==1.1.0 -kiwisolver==1.3.2 -manim==0.15.0 -manimlib==0.2.0 -ManimPango==0.4.0.post2 -mapbox-earcut==0.12.11 -Markdown==3.3.6 -masked-norm==0.0.0 -matplotlib==3.5.1 -moderngl==5.6.4 -moderngl-window==2.4.1 -multipledispatch==0.6.0 -networkx==2.6.3 -numpy==1.22.2 -oauthlib==3.2.0 -opencv-python==4.5.5.62 -packaging==21.3 -Pillow==9.0.1 -pkg_resources==0.0.0 -progressbar==2.5 -protobuf==3.19.4 -pyasn1==0.4.8 -pyasn1-modules==0.2.8 -pycairo==1.20.1 -pydub==0.25.1 -pyglet==1.5.21 -Pygments==2.11.2 -pyparsing==3.0.7 -pyrr==0.10.3 -python-dateutil==2.8.2 -pyunpack==0.2.2 -PyWavelets==1.2.0 -requests==2.27.1 -requests-oauthlib==1.3.1 -rich==11.2.0 -rsa==4.8 -scikit-image==0.19.2 -scikit-learn==1.0.2 -scipy==1.8.0 -screeninfo==0.8 -six==1.16.0 -skia-pathops==0.7.2 -srt==3.5.1 -tensorboard==2.8.0 -tensorboard-data-server==0.6.1 -tensorboard-plugin-wit==1.8.1 -threadpoolctl==3.1.0 -tifffile==2022.2.9 -torch==1.8.1 -torchvision==0.9.1 -tqdm==4.62.3 -typing_extensions==4.1.1 -urllib3==1.26.8 -watchdog==2.1.6 -Werkzeug==2.0.3 -zipp==3.7.0 +numpy==1.22.3 +opencv-python==4.5.5.64 +PyYAML==6.0 +torch==1.11.0 diff --git a/setup.py b/setup.py index bb7648588ce4b0c76de9cc0ed2351a18c9a91663..a0002964d8cb9cbb10e2a5ec8a3b3b62cfcc384b 100755 --- a/setup.py +++ b/setup.py @@ -6,8 +6,8 @@ from pathlib import Path from setuptools import find_packages, setup -def parse_requirements(): - path = Path(__file__).parent.resolve() / "prediction-requirements.txt" +def parse_requirements(filename): + path = Path(__file__).parent.resolve() / filename assert path.exists(), f"Missing requirements: {path}" return list(map(str.strip, path.read_text().splitlines())) @@ -21,6 +21,12 @@ setup( author="Teklia", author_email="contact@teklia.com", url="https://gitlab.com/teklia/dan", - install_requires=parse_requirements(), + install_requires=parse_requirements("requirements.txt"), packages=find_packages(), + entry_points={ + "console_scripts": [ + "teklia-dan=dan.cli:main", + ] + }, + extras_require={"predict": parse_requirements("prediction-requirements.txt")}, ) diff --git a/visual.png b/visual.png deleted file mode 100644 index e228e7924cb2760dc61636651b525ef5359408c5..0000000000000000000000000000000000000000 Binary files a/visual.png and /dev/null differ diff --git a/visual_slanted_lines.png b/visual_slanted_lines.png deleted file mode 100644 index 478af3ac1eb40316bbb1f0e7400f592b0d31ec70..0000000000000000000000000000000000000000 Binary files a/visual_slanted_lines.png and /dev/null differ