diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 09a6111c0d86921e36d749493e3ea43749109fbd..3703391f03074a9efb507832eeba643a8e901b62 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,6 +1,3 @@
-# Only run on the DAN python module
-files: '^dan'
-
 repos:
   - repo: https://github.com/PyCQA/isort
     rev: 5.10.1
diff --git a/Datasets/dataset_formatters/extract_from_arkindex.py b/Datasets/dataset_formatters/extract_from_arkindex.py
deleted file mode 100644
index 4bad36391d46e25a0f599362311ecef9c27f2da7..0000000000000000000000000000000000000000
--- a/Datasets/dataset_formatters/extract_from_arkindex.py
+++ /dev/null
@@ -1,113 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-#
-# Example of minimal usage:
-# python extract_from_arkindex.py
-#   --corpus "AN-Simara annotations E (2022-06-20)"
-#   --parents-types folder
-#   --parents-names FRAN_IR_032031_4538.pdf
-#   --output-dir ../
-
-"""
-    The extraction module
-    ======================
-"""
-
-import logging
-import os
-
-import cv2
-import imageio.v2 as iio
-from arkindex import ArkindexClient, options_from_env
-from tqdm import tqdm
-
-from extraction_utils import arkindex_utils as ark
-from extraction_utils import utils
-
-logging.basicConfig(
-    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
-)
-
-
-IMAGES_DIR = "./images/"  # Path to the images directory.
-LABELS_DIR = "./labels/"  # Path to the labels directory.
-
-# Layout string to token
-SEM_MATCHING_TOKENS_STR = {
-    "INTITULE": "ⓘ",
-    "DATE": "â““",
-    "COTE_SERIE": "â“¢",
-    "ANALYSE_COMPL.": "â“’",
-    "PRECISIONS_SUR_COTE": "â“Ÿ",
-    "COTE_ARTICLE": "ⓐ"
-}
-
-# Layout begin-token to end-token
-SEM_MATCHING_TOKENS = {
-    "ⓘ": "Ⓘ",
-    "â““": "â’¹",
-    "ⓢ": "Ⓢ",
-    "â“’": "â’¸",
-    "â“Ÿ": "â“…",
-    "ⓐ": "Ⓐ"
-}
-
-
-if __name__ == '__main__':
-    args = utils.get_cli_args()
-
-    # Get and initialize the parameters.
-    os.makedirs(IMAGES_DIR, exist_ok=True)
-    os.makedirs(LABELS_DIR, exist_ok=True)
-
-    # Login to arkindex.
-    client = ArkindexClient(**options_from_env())
-
-    corpus = ark.retrieve_corpus(client, args.corpus)
-    subsets = ark.retrieve_subsets(
-        client, corpus, args.parents_types, args.parents_names
-    )
-
-    # Iterate over the subsets to find the page images and labels.
-    for subset in subsets:
-
-        os.makedirs(os.path.join(args.output_dir, IMAGES_DIR, subset["name"]), exist_ok=True)
-        os.makedirs(os.path.join(args.output_dir, LABELS_DIR, subset["name"]), exist_ok=True)
-
-        for page in tqdm(
-            client.paginate(
-                "ListElementChildren", id=subset["id"], type="page", recursive=True
-            ),
-            desc="Set " + subset["name"],
-        ):
-
-            image = iio.imread(page["zone"]["url"])
-            cv2.imwrite(
-                os.path.join(args.output_dir, IMAGES_DIR, subset['name'], f"{page['id']}.jpg"),
-                cv2.cvtColor(image, cv2.COLOR_BGR2RGB),
-            )
-
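-            # Keep only manual transcriptions (those without a worker version).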
-            tr = client.request('ListTranscriptions', id=page['id'], worker_version=None)['results']
-            tr = [one for one in tr if one['worker_version_id'] is None]
-            assert len(tr) == 1, page['id']
-
-            for one_tr in tr:
-                ent = client.request('ListTranscriptionEntities', id=one_tr['id'])['results']
-                ent = [one for one in ent if one['worker_version_id'] is None]
-                if len(ent) == 0:
-                    continue
-                else:
-                    text = one_tr['text']
-
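-            # Wrap each entity with its begin/end layout tokens; `count`
-            # tracks how far offsets have shifted after previous insertions.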
-            new_text = text
-            count = 0
-            for e in ent:
-                start_token = SEM_MATCHING_TOKENS_STR[e['entity']['metas']['subtype']]
-                end_token = SEM_MATCHING_TOKENS[start_token]
-                new_text = new_text[:count+e['offset']] + start_token + new_text[count+e['offset']:]
-                count += 1
-                new_text = new_text[:count+e['offset']+e['length']] + end_token + new_text[count+e['offset']+e['length']:]
-                count += 1
-
-            with open(os.path.join(args.output_dir, LABELS_DIR, subset['name'], f"{page['id']}.txt"), 'w') as f:
-                f.write(new_text)
diff --git a/Datasets/dataset_formatters/extraction_utils/arkindex_utils.py b/Datasets/dataset_formatters/extraction_utils/arkindex_utils.py
deleted file mode 100644
index 5216a7e0dd286cdfb53c73802324cd40cd07f7bd..0000000000000000000000000000000000000000
--- a/Datasets/dataset_formatters/extraction_utils/arkindex_utils.py
+++ /dev/null
@@ -1,67 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-
-"""
-    The arkindex_utils module
-    ======================
-"""
-
-import errno
-import logging
-import sys
-
-from apistar.exceptions import ErrorResponse
-
-
-def retrieve_corpus(client, corpus_name: str) -> str:
-    """
-    Retrieve the corpus id from the corpus name.
-    :param client: The arkindex client.
-    :param corpus_name: The name of the corpus to retrieve.
-    :return target_corpus: The id of the retrieved corpus.
-    """
-    for corpus in client.request("ListCorpus"):
-        if corpus["name"] == corpus_name:
-            target_corpus = corpus["id"]
-    try:
-        logging.info(f"Corpus id retrieved: {target_corpus}")
-    except NameError:
-        logging.error(f"Corpus {corpus_name} not found")
-        sys.exit(errno.EINVAL)
-
-    return target_corpus
-
-
-def retrieve_subsets(
-    client, corpus: str, parents_types: list, parents_names: list
-) -> list:
-    """
-    Retrieve the requested subsets.
-    :param client: The arkindex client.
-    :param corpus: The id of the retrieved corpus.
-    :param parents_types: The types of parents of the elements to retrieve.
-    :param parents_names: The names of parents of the elements to retrieve.
-    :return subsets: The retrieved subsets.
-    """
-    subsets = []
-    for parent_type in parents_types:
-        try:
-            subsets.extend(
-                client.request("ListElements", corpus=corpus, type=parent_type)[
-                    "results"
-                ]
-            )
-        except ErrorResponse as e:
-            logging.error(f"{e.content}: {parent_type}")
-            sys.exit(errno.EINVAL)
-    # Retrieve subsets with name in parents-names. If no parents-names given, keep all subsets.
-    if parents_names is not None:
-        logging.info(f"Retrieving {parents_names} subset(s)")
-        subsets = [subset for subset in subsets if subset["name"] in parents_names]
-    else:
-        logging.info("Retrieving all subsets")
-
-    if len(subsets) == 0:
-        logging.info("No subset found")
-
-    return subsets
diff --git a/Datasets/dataset_formatters/extraction_utils/utils.py b/Datasets/dataset_formatters/extraction_utils/utils.py
deleted file mode 100644
index fad5cdec84a0b68676804165a0606bbc63c8c7f9..0000000000000000000000000000000000000000
--- a/Datasets/dataset_formatters/extraction_utils/utils.py
+++ /dev/null
@@ -1,49 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-
-"""
-    The utils module
-    ======================
-"""
-
-import argparse
-
-def get_cli_args():
-    """
-    Get the command-line arguments.
-    :return: The command-line arguments.
-    """
-    parser = argparse.ArgumentParser(
-        description="Arkindex DAN Training Label Generation"
-    )
-
-    # Required arguments.
-    parser.add_argument(
-        "--corpus",
-        type=str,
-        help="Name of the corpus from which the data will be retrieved.",
-        required=True,
-    )
-    parser.add_argument(
-        "--parents-types",
-        nargs="+",
-        type=str,
-        help="Type of parents of the elements.",
-        required=True,
-    )
-    parser.add_argument(
-        "--output-dir",
-        type=str,
-        help="Path to the output directory.",
-        required=True,
-    )
-
-    # Optional arguments.
-    parser.add_argument(
-        "--parents-names",
-        nargs="+",
-        type=str,
-        help="Names of parents of the elements.",
-        default=None,
-    )
-    return parser.parse_args()
diff --git a/Datasets/dataset_formatters/generic_dataset_formatter.py b/Datasets/dataset_formatters/generic_dataset_formatter.py
deleted file mode 100644
index ae55af026c841af4a9ab8ee08e3437bf9d16c034..0000000000000000000000000000000000000000
--- a/Datasets/dataset_formatters/generic_dataset_formatter.py
+++ /dev/null
@@ -1,89 +0,0 @@
-import os
-import shutil
-import tarfile
-import pickle
-import re
-from PIL import Image
-import numpy as np
-
-
-class DatasetFormatter:
-    """
-    Global pipeline/functions for dataset formatting
-    """
-
-    def __init__(self, dataset_name, level, extra_name="", set_names=["train", "valid", "test"]):
-        self.dataset_name = dataset_name
-        self.level = level
-        self.set_names = set_names
-        self.target_fold_path = os.path.join(
-            "Datasets", "formatted", "{}_{}{}".format(dataset_name, level, extra_name))
-        self.map_datasets_files = dict()
-        self.extract_with_dirname = False
-
-    def format(self):
-        self.init_format()
-        self.map_datasets_files[self.dataset_name][self.level]["format_function"]()
-        self.end_format()
-
-    def init_format(self):
-        """
-        Load and extract needed files
-        """
-        os.makedirs(self.target_fold_path, exist_ok=True)
-
-        for set_name in self.set_names:
-            os.makedirs(os.path.join(self.target_fold_path, set_name), exist_ok=True)
-
-
-class OCRDatasetFormatter(DatasetFormatter):
-    """
-    Specific pipeline/functions for OCR/HTR dataset formatting
-    """
-
-    def __init__(self, source_dataset, level, extra_name="", set_names=["train", "valid", "test"]):
-        super(OCRDatasetFormatter, self).__init__(source_dataset, level, extra_name, set_names)
-        self.charset = set()
-        self.gt = dict()
-        for set_name in set_names:
-            self.gt[set_name] = dict()
-
-    def format_text_label(self, label):
-        """
-        Remove extra space or line break characters
-        """
-        temp = re.sub("(\n)+", '\n', label)
-        return re.sub("( )+", ' ', temp).strip(" \n")
-
-    def load_resize_save(self, source_path, target_path):
-        """
-        Load image, apply resolution modification and save it
-        """
-        shutil.copyfile(source_path, target_path)
-
-    def resize(self, img, source_dpi, target_dpi):
-        """
-        Apply resolution modification to image
-        """
-        if source_dpi == target_dpi:
-            return img
-        if isinstance(img, np.ndarray):
-            h, w = img.shape[:2]
-            img = Image.fromarray(img)
-        else:
-            w, h = img.size
-        ratio = target_dpi / source_dpi
-        img = img.resize((int(w*ratio), int(h*ratio)), Image.BILINEAR)
-        return np.array(img)
-
-    def end_format(self):
-        """
-        Save label and charset files
-        """
-        with open(os.path.join(self.target_fold_path, "labels.pkl"), "wb") as f:
-            pickle.dump({
-                "ground_truth": self.gt,
-                "charset": sorted(list(self.charset)),
-            }, f)
-        with open(os.path.join(self.target_fold_path, "charset.pkl"), "wb") as f:
-            pickle.dump(sorted(list(self.charset)), f)
diff --git a/Datasets/dataset_formatters/simara_formatter.py b/Datasets/dataset_formatters/simara_formatter.py
deleted file mode 100644
index c60eb30f94cf569d3b2d9d9197a16756ee994c01..0000000000000000000000000000000000000000
--- a/Datasets/dataset_formatters/simara_formatter.py
+++ /dev/null
@@ -1,104 +0,0 @@
-from Datasets.dataset_formatters.generic_dataset_formatter import OCRDatasetFormatter
-import os
-import numpy as np
-from Datasets.dataset_formatters.utils_dataset import natural_sort
-from PIL import Image
-import xml.etree.ElementTree as ET
-import re
-from tqdm import tqdm
-
-# Layout string to token
-SEM_MATCHING_TOKENS_STR = {
-    "INTITULE": "ⓘ",
-    "DATE": "â““",
-    "COTE_SERIE": "â“¢",
-    "ANALYSE_COMPL": "â“’",
-    "PRECISIONS_SUR_COTE": "â“Ÿ",
-    "COTE_ARTICLE": "ⓐ"
-}
-
-# Layout begin-token to end-token
-SEM_MATCHING_TOKENS = {
-    "ⓘ": "Ⓘ",
-    "â““": "â’¹",
-    "ⓢ": "Ⓢ",
-    "â“’": "â’¸",
-    "â“Ÿ": "â“…",
-    "ⓐ": "Ⓐ"
-}
-
-class SimaraDatasetFormatter(OCRDatasetFormatter):
-    def __init__(self, level, set_names=["train", "valid", "test"], dpi=150, sem_token=True):
-        super(SimaraDatasetFormatter, self).__init__("simara", level, "_sem" if sem_token else "", set_names)
-
-        self.dpi = dpi
-        self.sem_token = sem_token
-        self.map_datasets_files.update({
-            "simara": {
-                # (1,050 for train, 100 for validation and 100 for test)
-                "page": {
-                    "format_function": self.format_simara_page,
-                },
-            }
-        })
-        self.matching_tokens_str = SEM_MATCHING_TOKENS_STR
-        self.matching_tokens = SEM_MATCHING_TOKENS
-
-    def preformat_simara_page(self):
-        """
-        Extract all information from dataset and correct some annotations
-        """
-        dataset = {
-            "train": list(),
-            "valid": list(),
-            "test": list()
-        }
-        img_folder_path = os.path.join("Datasets", "raw", "simara", "images")
-        labels_folder_path = os.path.join("Datasets", "raw", "simara", "labels")
-        sem_labels_folder_path = os.path.join("Datasets", "raw", "simara", "labels_sem")
-        train_files = [
-            os.path.join(labels_folder_path, 'train', name)
-            for name in os.listdir(os.path.join(sem_labels_folder_path, 'train'))]
-        valid_files = [
-            os.path.join(labels_folder_path, 'valid', name)
-            for name in os.listdir(os.path.join(sem_labels_folder_path, 'valid'))]
-        test_files = [
-            os.path.join(labels_folder_path, 'test', name)
-            for name in os.listdir(os.path.join(sem_labels_folder_path, 'test'))]
-        for set_name, files in zip(self.set_names, [train_files, valid_files, test_files]):
-            for i, label_file in enumerate(tqdm(files, desc='Pre-formatting '+set_name)):
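-                # Read both the plain transcription and the variant with
-                # semantic layout tokens (from the labels_sem folder).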
-                with open(label_file, 'r') as f:
-                    text = f.read()
-                with open(label_file.replace('labels', 'labels_sem'), 'r') as f:
-                    sem_text = f.read()
-                dataset[set_name].append({
-                    "img_path": os.path.join(
-                        img_folder_path, set_name, label_file.split('/')[-1].replace('txt', 'jpg')),
-                    "label": text,
-                    "sem_label": sem_text,
-                })
-        print(dataset['test'], len(dataset['test']))
-        return dataset
-
-    def format_simara_page(self):
-        """
-        Format simara page dataset
-        """
-        dataset = self.preformat_simara_page()
-        for set_name in self.set_names:
-            fold = os.path.join(self.target_fold_path, set_name)
-            for sample in tqdm(dataset[set_name], desc='Formatting '+set_name):
-                new_name = sample['img_path'].split('/')[-1]
-                new_img_path = os.path.join(fold, new_name)
-                self.load_resize_save(sample["img_path"], new_img_path)#, 300, self.dpi)
-                page = {
-                    "text": sample["label"] if not self.sem_token else sample["sem_label"],
-                }
-                self.charset = self.charset.union(set(page["text"]))
-                self.gt[set_name][new_name] = page
-
-
-if __name__ == "__main__":
-
-    SimaraDatasetFormatter("page", sem_token=True).format()
-    #SimaraDatasetFormatter("page", sem_token=False).format()
diff --git a/Datasets/dataset_formatters/utils_dataset.py b/Datasets/dataset_formatters/utils_dataset.py
deleted file mode 100644
index 962274da3a806f729904d839481f294e1d5d6ca2..0000000000000000000000000000000000000000
--- a/Datasets/dataset_formatters/utils_dataset.py
+++ /dev/null
@@ -1,7 +0,0 @@
-import re
-
-
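-# Sort strings so that embedded numbers compare numerically
-# (e.g. "page2" sorts before "page10").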
-def natural_sort(l):
-    convert = lambda text: int(text) if text.isdigit() else text.lower()
-    alphanum_key = lambda key: [convert(c) for c in re.split('([0-9]+)', key)]
-    return sorted(l, key=alphanum_key)
\ No newline at end of file
diff --git a/OCR/document_OCR/dan/main_dan.py b/OCR/document_OCR/dan/main_dan.py
deleted file mode 100644
index 2a15f0aa8ee4ea4f4b9bd9297a74180827ce71a5..0000000000000000000000000000000000000000
--- a/OCR/document_OCR/dan/main_dan.py
+++ /dev/null
@@ -1,207 +0,0 @@
-import os
-import sys
-DOSSIER_COURRANT = os.path.dirname(os.path.abspath(__file__))
-DOSSIER_PARENT = os.path.dirname(DOSSIER_COURRANT)
-sys.path.append(os.path.dirname(DOSSIER_PARENT))
-sys.path.append(os.path.dirname(os.path.dirname(DOSSIER_PARENT)))
-sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(DOSSIER_PARENT))))
-from torch.optim import Adam
-from basic.transforms import aug_config
-from OCR.ocr_dataset_manager import OCRDataset, OCRDatasetManager
-from OCR.document_OCR.dan.trainer_dan import Manager
-from dan.decoder import GlobalHTADecoder
-from dan.models import FCN_Encoder
-from basic.scheduler import exponential_dropout_scheduler, linear_scheduler
-import torch
-import numpy as np
-import random
-import torch.multiprocessing as mp
-
-
-def train_and_test(rank, params):
-    torch.manual_seed(0)
-    torch.cuda.manual_seed(0)
-    np.random.seed(0)
-    random.seed(0)
-    torch.backends.cudnn.benchmark = False
-    torch.backends.cudnn.deterministic = True
-
-    params["training_params"]["ddp_rank"] = rank
-    model = Manager(params)
-    model.load_model()
-
-    model.train()
-
-    # load weights giving best CER on valid set
-    model.params["training_params"]["load_epoch"] = "best"
-    model.load_model()
-
-    metrics = ["cer", "wer", "time", "map_cer",  "loer"]
-    for dataset_name in params["dataset_params"]["datasets"].keys():
-        for set_name in ["test", "valid", "train"]:
-            model.predict("{}-{}".format(dataset_name, set_name), [(dataset_name, set_name), ], metrics, output=True)
-
-
-if __name__ == "__main__":
-
-    dataset_name = "simara"  # ["RIMES", "READ_2016"]
-    dataset_level = "page"  # ["page", "double_page"]
-    dataset_variant = "_sem"
-
-    # max number of lines for synthetic documents
-    max_nb_lines = {
-        "RIMES": 40,
-        "READ_2016": 30,
-    }
-
-    params = {
-        "dataset_params": {
-            "dataset_manager": OCRDatasetManager,
-            "dataset_class": OCRDataset,
-            "datasets": {
-                dataset_name: "../../../Datasets/formatted/{}_{}{}".format(dataset_name, dataset_level, dataset_variant),
-            },
-            "train": {
-                "name": "{}-train".format(dataset_name),
-                "datasets": [(dataset_name, "train"), ],
-            },
-            "valid": {
-                "{}-valid".format(dataset_name): [(dataset_name, "valid"), ],
-            },
-            "config": {
-                "load_in_memory": True,  # Load all images in CPU memory
-                "worker_per_gpu": 4,  # Num of parallel processes per gpu for data loading
-                "width_divisor": 8,  # Image width will be divided by 8
-                "height_divisor": 32,  # Image height will be divided by 32
-                "padding_value": 0,  # Image padding value
-                "padding_token": None,  # Label padding value
-                "charset_mode": "seq2seq",  # add end-of-transcription ans start-of-transcription tokens to charset
-                "constraints": ["add_eot", "add_sot"],  # add end-of-transcription ans start-of-transcription tokens in labels
-                "normalize": True,  # Normalize with mean and variance of training dataset
-                "preprocessings": [
-                    {
-                        "type": "to_RGB",
-                        # if grayscale image, produce RGB one (3 channels with same value), otherwise do nothing
-                    },
-                ],
-                "augmentation": aug_config(0.9, 0.1),
-                "synthetic_data": None,
-                #"synthetic_data": {
-                #    "init_proba": 0.9,  # begin proba to generate synthetic document
-                #    "end_proba": 0.2,  # end proba to generate synthetic document
-                #    "num_steps_proba": 200000,  # linearly decrease the percent of synthetic document from 90% to 20% through 200000 samples
-                #    "proba_scheduler_function": linear_scheduler,  # decrease proba rate linearly
-                #    "start_scheduler_at_max_line": True,  # start decreasing proba only after curriculum reach max number of lines
-                #    "dataset_level": dataset_level,
-                #    "curriculum": True,  # use curriculum learning (slowly increase number of lines per synthetic samples)
-                #    "crop_curriculum": True,  # during curriculum learning, crop images under the last text line
-                #    "curr_start": 0,  # start curriculum at iteration
-                #    "curr_step": 10000,  # interval to increase the number of lines for curriculum learning
-                #    "min_nb_lines": 1,  # initial number of lines for curriculum learning
-                #    "max_nb_lines": max_nb_lines[dataset_name],  # maximum number of lines for curriculum learning
-                #    "padding_value": 255,
-                #    # config for synthetic line generation
-                #    "config": {
-                #        "background_color_default": (255, 255, 255),
-                #        "background_color_eps": 15,
-                #        "text_color_default": (0, 0, 0),
-                #        "text_color_eps": 15,
-                #        "font_size_min": 35,
-                #        "font_size_max": 45,
-                #        "color_mode": "RGB",
-                #        "padding_left_ratio_min": 0.00,
-                #        "padding_left_ratio_max": 0.05,
-                #        "padding_right_ratio_min": 0.02,
-                #        "padding_right_ratio_max": 0.2,
-                #        "padding_top_ratio_min": 0.02,
-                #        "padding_top_ratio_max": 0.1,
-                #        "padding_bottom_ratio_min": 0.02,
-                #        "padding_bottom_ratio_max": 0.1,
-                #    },
-                #}
-            }
-        },
-
-        "model_params": {
-            "models": {
-                "encoder": FCN_Encoder,
-                "decoder": GlobalHTADecoder,
-            },
-            #"transfer_learning": None,
-            "transfer_learning": {
-               # model_name: [state_dict_name, checkpoint_path, learnable, strict]
-                "encoder": ["encoder", "dan_rimes_page.pt", True, True],
-                "decoder": ["decoder", "dan_rimes_page.pt", True, False],
-            },
-            "transfered_charset": True,  # Transfer learning of the decision layer based on charset of the line HTR model
-            "additional_tokens": 1,  # for decision layer = [<eot>, ], only for transfered charset
-
-            "input_channels": 3,  # number of channels of input image
-            "dropout": 0.5,  # dropout rate for encoder
-            "enc_dim": 256,  # dimension of extracted features
-            "nb_layers": 5,  # encoder
-            "h_max": 500,  # maximum height for encoder output (for 2D positional embedding)
-            "w_max": 1000,  # maximum width for encoder output (for 2D positional embedding)
-            "l_max": 15000,  # max predicted sequence (for 1D positional embedding)
-            "dec_num_layers": 8,  # number of transformer decoder layers
-            "dec_num_heads": 4,  # number of heads in transformer decoder layers
-            "dec_res_dropout": 0.1,  # dropout in transformer decoder layers
-            "dec_pred_dropout": 0.1,  # dropout rate before decision layer
-            "dec_att_dropout": 0.1,  # dropout rate in multi head attention
-            "dec_dim_feedforward": 256,  # number of dimension for feedforward layer in transformer decoder layers
-            "use_2d_pe": True,  # use 2D positional embedding
-            "use_1d_pe": True,  # use 1D positional embedding
-            "use_lstm": False,
-            "attention_win": 100,  # length of attention window
-            # Curriculum dropout
-            "dropout_scheduler": {
-                "function": exponential_dropout_scheduler,
-                "T": 5e4,
-            }
-
-        },
-
-        "training_params": {
-            "output_folder": "dan_simara_page",  # folder name for checkpoint and results
-            "max_nb_epochs": 50000,  # maximum number of epochs before to stop
-            "max_training_time": 3600 * 24 * 1.9,  # maximum time before to stop (in seconds)
-            "load_epoch": "last",  # ["best", "last"]: last to continue training, best to evaluate
-            "interval_save_weights": None,  # None: keep best and last only
-            "batch_size": 2,  # mini-batch size for training
-            "valid_batch_size": 4,  # mini-batch size for valdiation
-            "use_ddp": False,  # Use DistributedDataParallel
-            "ddp_port": "20027",
-            "use_amp": True,  # Enable automatic mix-precision
-            "nb_gpu": torch.cuda.device_count(),
-            "optimizers": {
-                "all": {
-                    "class": Adam,
-                    "args": {
-                        "lr": 0.0001,
-                        "amsgrad": False,
-                    }
-                },
-            },
-            "lr_schedulers": None,  # Learning rate schedulers
-            "eval_on_valid": True,  # Whether to eval and logs metrics on validation set during training or not
-            "eval_on_valid_interval": 5,  # Interval (in epochs) to evaluate during training
-            "focus_metric": "cer",  # Metrics to focus on to determine best epoch
-            "expected_metric_value": "low",  # ["high", "low"] What is best for the focus metric value
-            "set_name_focus_metric": "{}-valid".format(dataset_name),  # Which dataset to focus on to select best weights
-            "train_metrics": ["loss_ce", "cer", "wer", "syn_max_lines"],  # Metrics name for training
-            "eval_metrics": ["cer", "wer", "map_cer"],  # Metrics name for evaluation on validation set during training
-            "force_cpu": False,  # True for debug purposes
-            "max_char_prediction": 1000,  # max number of token prediction
-            # Keep teacher forcing rate to 20% during whole training
-            "teacher_forcing_scheduler": {
-                "min_error_rate": 0.2,
-                "max_error_rate": 0.2,
-                "total_num_steps": 5e4
-            },
-        },
-    }
-
-    if params["training_params"]["use_ddp"] and not params["training_params"]["force_cpu"]:
-        mp.spawn(train_and_test, args=(params,), nprocs=params["training_params"]["nb_gpu"])
-    else:
-        train_and_test(0, params)
diff --git a/OCR/document_OCR/dan/trainer_dan.py b/OCR/document_OCR/dan/trainer_dan.py
deleted file mode 100644
index d8fd84324e88ff44e2983abfdcf53ac70709c08b..0000000000000000000000000000000000000000
--- a/OCR/document_OCR/dan/trainer_dan.py
+++ /dev/null
@@ -1,165 +0,0 @@
-from OCR.ocr_manager import OCRManager
-from torch.nn import CrossEntropyLoss
-import torch
-from dan.ocr_utils import LM_ind_to_str
-import numpy as np
-from torch.cuda.amp import autocast
-import time
-
-
-class Manager(OCRManager):
-
-    def __init__(self, params):
-        super(Manager, self).__init__(params)
-
-    def load_save_info(self, info_dict):
-        if "curriculum_config" in info_dict.keys():
-            if self.dataset.train_dataset is not None:
-                self.dataset.train_dataset.curriculum_config = info_dict["curriculum_config"]
-
-    def add_save_info(self, info_dict):
-        info_dict["curriculum_config"] = self.dataset.train_dataset.curriculum_config
-        return info_dict
-
-    def get_init_hidden(self, batch_size):
-        num_layers = 1
-        hidden_size = self.params["model_params"]["enc_dim"]
-        return torch.zeros(num_layers, batch_size, hidden_size), torch.zeros(num_layers, batch_size, hidden_size)
-
-    def apply_teacher_forcing(self, y, y_len, error_rate):
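-        # Randomly replace ground-truth tokens (except padding) with random
-        # token indices so the decoder learns to recover from its own errors.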
-        y_error = y.clone()
-        for b in range(len(y_len)):
-            for i in range(1, y_len[b]):
-                if np.random.rand() < error_rate and y[b][i] != self.dataset.tokens["pad"]:
-                    y_error[b][i] = np.random.randint(0, len(self.dataset.charset)+2)
-        return y_error, y_len
-
-    def train_batch(self, batch_data, metric_names):
-        loss_func = CrossEntropyLoss(ignore_index=self.dataset.tokens["pad"])
-
-        sum_loss = 0
-        x = batch_data["imgs"].to(self.device)
-        y = batch_data["labels"].to(self.device)
-        reduced_size = [s[:2] for s in batch_data["imgs_reduced_shape"]]
-        y_len = batch_data["labels_len"]
-
-        # add errors in teacher forcing
-        if "teacher_forcing_error_rate" in self.params["training_params"] and self.params["training_params"]["teacher_forcing_error_rate"] is not None:
-            error_rate = self.params["training_params"]["teacher_forcing_error_rate"]
-            simulated_y_pred, y_len = self.apply_teacher_forcing(y, y_len, error_rate)
-        elif "teacher_forcing_scheduler" in self.params["training_params"]:
-            tf = self.params["training_params"]["teacher_forcing_scheduler"]
-            progress = min(self.latest_step, tf["total_num_steps"]) / tf["total_num_steps"]
-            error_rate = tf["min_error_rate"] + progress * (tf["max_error_rate"] - tf["min_error_rate"])
-            simulated_y_pred, y_len = self.apply_teacher_forcing(y, y_len, error_rate)
-        else:
-            simulated_y_pred = y
-
-        with autocast(enabled=self.params["training_params"]["use_amp"]):
-            hidden_predict = None
-            cache = None
-
-            raw_features = self.models["encoder"](x)
-            features_size = raw_features.size()
-            b, c, h, w = features_size
-
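-            # Flatten the 2D feature map into a sequence of h*w positions
-            # ((H*W, B, C) layout), as expected by the transformer decoder.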
-            pos_features = self.models["decoder"].features_updater.get_pos_features(raw_features)
-            features = torch.flatten(pos_features, start_dim=2, end_dim=3).permute(2, 0, 1)
-            enhanced_features = pos_features
-            enhanced_features = torch.flatten(enhanced_features, start_dim=2, end_dim=3).permute(2, 0, 1)
-            output, pred, hidden_predict, cache, weights = self.models["decoder"](features, enhanced_features,
-                                                                               simulated_y_pred[:, :-1],
-                                                                               reduced_size,
-                                                                               [max(y_len) for _ in range(b)],
-                                                                               features_size,
-                                                                               start=0,
-                                                                               hidden_predict=hidden_predict,
-                                                                               cache=cache,
-                                                                               keep_all_weights=True)
-
-            loss_ce = loss_func(pred, y[:, 1:])
-            sum_loss += loss_ce
-            with autocast(enabled=False):
-                self.backward_loss(sum_loss)
-                self.step_optimizers()
-                self.zero_optimizers()
-            predicted_tokens = torch.argmax(pred, dim=1).detach().cpu().numpy()
-            predicted_tokens = [predicted_tokens[i, :y_len[i]] for i in range(b)]
-            str_x = [LM_ind_to_str(self.dataset.charset, t, oov_symbol="") for t in predicted_tokens]
-
-        values = {
-            "nb_samples": b,
-            "str_y": batch_data["raw_labels"],
-            "str_x": str_x,
-            "loss": sum_loss.item(),
-            "loss_ce": loss_ce.item(),
-            "syn_max_lines": self.dataset.train_dataset.get_syn_max_lines() if self.params["dataset_params"]["config"]["synthetic_data"] else 0,
-        }
-
-        return values
-
-    def evaluate_batch(self, batch_data, metric_names):
-        x = batch_data["imgs"].to(self.device)
-        reduced_size = [s[:2] for s in batch_data["imgs_reduced_shape"]]
-
-        max_chars = self.params["training_params"]["max_char_prediction"]
-
-        start_time = time.time()
-        with autocast(enabled=self.params["training_params"]["use_amp"]):
-            b = x.size(0)
-            reached_end = torch.zeros((b, ), dtype=torch.bool, device=self.device)
-            prediction_len = torch.zeros((b, ), dtype=torch.int, device=self.device)
-            predicted_tokens = torch.ones((b, 1), dtype=torch.long, device=self.device) * self.dataset.tokens["start"]
-            predicted_tokens_len = torch.ones((b, ), dtype=torch.int, device=self.device)
-
-            whole_output = list()
-            confidence_scores = list()
-            cache = None
-            hidden_predict = None
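-            # Crop each padded image back to its true extent, encode it alone,
-            # then zero-pad the feature maps to a common shape for batched decoding.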
-            if b > 1:
-                features_list = list()
-                for i in range(b):
-                    pos = batch_data["imgs_position"]
-                    features_list.append(self.models["encoder"](x[i:i+1, :, pos[i][0][0]:pos[i][0][1], pos[i][1][0]:pos[i][1][1]]))
-                max_height = max([f.size(2) for f in features_list])
-                max_width = max([f.size(3) for f in features_list])
-                features = torch.zeros((b, features_list[0].size(1), max_height, max_width), device=self.device, dtype=features_list[0].dtype)
-                for i in range(b):
-                    features[i, :, :features_list[i].size(2), :features_list[i].size(3)] = features_list[i]
-            else:
-                features = self.models["encoder"](x)
-            features_size = features.size()
-            coverage_vector = torch.zeros((features.size(0), 1, features.size(2), features.size(3)), device=self.device)
-            pos_features = self.models["decoder"].features_updater.get_pos_features(features)
-            features = torch.flatten(pos_features, start_dim=2, end_dim=3).permute(2, 0, 1)
-            enhanced_features = pos_features
-            enhanced_features = torch.flatten(enhanced_features, start_dim=2, end_dim=3).permute(2, 0, 1)
-
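-            # Autoregressive decoding: predict one token per iteration until
-            # every sequence reaches the end token or max_chars is hit.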
-            for i in range(0, max_chars):
-                output, pred, hidden_predict, cache, weights = self.models["decoder"](features, enhanced_features, predicted_tokens, reduced_size, predicted_tokens_len, features_size, start=0, hidden_predict=hidden_predict, cache=cache, num_pred=1)
-                whole_output.append(output)
-                confidence_scores.append(torch.max(torch.softmax(pred[:, :], dim=1), dim=1).values)
-                coverage_vector = torch.clamp(coverage_vector + weights, 0, 1)
-                predicted_tokens = torch.cat([predicted_tokens, torch.argmax(pred[:, :, -1], dim=1, keepdim=True)], dim=1)
-                reached_end = torch.logical_or(reached_end, torch.eq(predicted_tokens[:, -1], self.dataset.tokens["end"]))
-                predicted_tokens_len += 1
-
-                prediction_len[reached_end == False] = i + 1
-                if torch.all(reached_end):
-                    break
-
-            confidence_scores = torch.cat(confidence_scores, dim=1).cpu().detach().numpy()
-            predicted_tokens = predicted_tokens[:, 1:]
-            prediction_len[torch.eq(reached_end, False)] = max_chars - 1
-            predicted_tokens = [predicted_tokens[i, :prediction_len[i]] for i in range(b)]
-            confidence_scores = [confidence_scores[i, :prediction_len[i]].tolist() for i in range(b)]
-            str_x = [LM_ind_to_str(self.dataset.charset, t, oov_symbol="") for t in predicted_tokens]
-
-        process_time = time.time() - start_time
-
-        values = {
-            "nb_samples": b,
-            "str_y": batch_data["raw_labels"],
-            "str_x": str_x,
-            "confidence_score": confidence_scores,
-            "time": process_time,
-        }
-        return values
diff --git a/OCR/line_OCR/ctc/main_line_ctc.py b/OCR/line_OCR/ctc/main_line_ctc.py
deleted file mode 100644
index bb8fa1d6c277b8978bb251eb0259e17603e9e28e..0000000000000000000000000000000000000000
--- a/OCR/line_OCR/ctc/main_line_ctc.py
+++ /dev/null
@@ -1,173 +0,0 @@
-
-import os
-import sys
-from os.path import dirname
-DOSSIER_COURRANT = dirname(os.path.abspath(__file__))
-ROOT_FOLDER = dirname(dirname(dirname(DOSSIER_COURRANT)))
-sys.path.append(ROOT_FOLDER)
-from OCR.line_OCR.ctc.trainer_line_ctc import TrainerLineCTC
-from OCR.line_OCR.ctc.models_line_ctc import Decoder
-from dan.models import FCN_Encoder
-from torch.optim import Adam
-from basic.transforms import line_aug_config
-from basic.scheduler import exponential_dropout_scheduler, exponential_scheduler
-from OCR.ocr_dataset_manager import OCRDataset, OCRDatasetManager
-import torch.multiprocessing as mp
-import torch
-import numpy as np
-import random
-
-
-def train_and_test(rank, params):
-    torch.manual_seed(0)
-    torch.cuda.manual_seed(0)
-    np.random.seed(0)
-    random.seed(0)
-    torch.backends.cudnn.benchmark = False
-    torch.backends.cudnn.deterministic = True
-
-    params["training_params"]["ddp_rank"] = rank
-    model = TrainerLineCTC(params)
-    model.load_model()
-
-    # Model trains until max_time_training or max_nb_epochs is reached
-    model.train()
-
-    # load weights giving best CER on valid set
-    model.params["training_params"]["load_epoch"] = "best"
-    model.load_model()
-
-    # compute metrics on train, valid and test sets (in eval conditions)
-    metrics = ["cer", "wer", "time", ]
-    for dataset_name in params["dataset_params"]["datasets"].keys():
-        for set_name in ["test", "valid", "train", ]:
-            model.predict("{}-{}".format(dataset_name, set_name), [(dataset_name, set_name), ], metrics, output=True)
-
-
-def main():
-    dataset_name = "READ_2016"  # ["RIMES", "READ_2016"]
-    dataset_level = "syn_line"
-    params = {
-        "dataset_params": {
-            "dataset_manager": OCRDatasetManager,
-            "dataset_class": OCRDataset,
-            "datasets": {
-                dataset_name: "../../../Datasets/formatted/{}_{}".format(dataset_name, dataset_level),
-            },
-            "train": {
-                "name": "{}-train".format(dataset_name),
-                "datasets": [(dataset_name, "train"), ],
-            },
-            "valid": {
-                "{}-valid".format(dataset_name): [(dataset_name, "valid"), ],
-            },
-            "config": {
-                "load_in_memory": True,  # Load all images in CPU memory
-                "worker_per_gpu": 8,  # Num of parallel processes per gpu for data loading
-                "width_divisor": 8,  # Image width will be divided by 8
-                "height_divisor": 32,  # Image height will be divided by 32
-                "padding_value": 0,  # Image padding value
-                "padding_token": 1000,  # Label padding value (None: default value is chosen)
-                "padding_mode": "br",  # Padding at bottom and right
-                "charset_mode": "CTC",  # add blank token
-                "constraints": ["CTC_line", ],  # Padding for CTC requirements if necessary
-                "normalize": True,  # Normalize with mean and variance of training dataset
-                "padding": {
-                    "min_height": "max",  # Pad to reach max height of training samples
-                    "min_width": "max",  # Pad to reach max width of training samples
-                    "min_pad": None,
-                    "max_pad": None,
-                    "mode": "br",  # Padding at bottom and right
-                    "train_only": False,  # Add padding at training time and evaluation time
-                },
-                "preprocessings": [
-                    {
-                        "type": "to_RGB",
-                        # if grayscale image, produce RGB one (3 channels with same value) otherwise do nothing
-                    },
-                ],
-                # Augmentation techniques to use at training time
-                "augmentation": line_aug_config(0.9, 0.1),
-                #
-                "synthetic_data": {
-                    "mode": "line_hw_to_printed",
-                    "init_proba": 1,
-                    "end_proba": 1,
-                    "num_steps_proba": 1e5,
-                    "proba_scheduler_function": exponential_scheduler,
-                    "config": {
-                        "background_color_default": (255, 255, 255),
-                        "background_color_eps": 15,
-                        "text_color_default": (0, 0, 0),
-                        "text_color_eps": 15,
-                        "font_size_min": 30,
-                        "font_size_max": 50,
-                        "color_mode": "RGB",
-                        "padding_left_ratio_min": 0.02,
-                        "padding_left_ratio_max": 0.1,
-                        "padding_right_ratio_min": 0.02,
-                        "padding_right_ratio_max": 0.1,
-                        "padding_top_ratio_min": 0.02,
-                        "padding_top_ratio_max": 0.2,
-                        "padding_bottom_ratio_min": 0.02,
-                        "padding_bottom_ratio_max": 0.2,
-                    },
-                },
-            }
-        },
-
-        "model_params": {
-            # Model classes to use for each module
-            "models": {
-                "encoder": FCN_Encoder,
-                "decoder": Decoder,
-            },
-            "transfer_learning": None,
-            "input_channels": 3,  # 1 for grayscale images, 3 for RGB ones (or grayscale as RGB)
-            "enc_size": 256,
-            "dropout_scheduler": {
-                "function": exponential_dropout_scheduler,
-                "T": 5e4,
-            },
-            "dropout": 0.5,
-        },
-
-        "training_params": {
-            "output_folder": "FCN_read_2016_line_syn",  # folder names for logs and weigths
-            "max_nb_epochs": 10000,  # max number of epochs for the training
-            "max_training_time": 3600 * 24 * 1.9,  # max training time limit (in seconds)
-            "load_epoch": "last",  # ["best", "last"], to load weights from best epoch or last trained epoch
-            "interval_save_weights": None,  # None: keep best and last only
-            "use_ddp": False,  # Use DistributedDataParallel
-            "use_amp": True,  # Enable automatic mix-precision
-            "nb_gpu": torch.cuda.device_count(),
-            "batch_size": 16,  # mini-batch size per GPU
-            "optimizers": {
-                "all": {
-                    "class": Adam,
-                    "args": {
-                        "lr": 0.0001,
-                        "amsgrad": False,
-                    }
-                }
-            },
-            "lr_schedulers": None,  # Learning rate schedulers
-            "eval_on_valid": True,  # Whether to eval and logs metrics on validation set during training or not
-            "eval_on_valid_interval": 2,  # Interval (in epochs) to evaluate during training
-            "focus_metric": "cer",  # Metrics to focus on to determine best epoch
-            "expected_metric_value": "low",  # ["high", "low"] What is best for the focus metric value
-            "set_name_focus_metric": "{}-valid".format(dataset_name),  # Which dataset to focus on to select best weights
-            "train_metrics": ["loss_ctc", "cer", "wer"],  # Metrics name for training
-            "eval_metrics": ["loss_ctc", "cer", "wer"],  # Metrics name for evaluation on validation set during training
-            "force_cpu": False,  # True for debug purposes to run on cpu only
-        },
-    }
-
-    if params["training_params"]["use_ddp"] and not params["training_params"]["force_cpu"]:
-        mp.spawn(train_and_test, args=(params,), nprocs=params["training_params"]["nb_gpu"])
-    else:
-        train_and_test(0, params)
-
-
-if __name__ == "__main__":
-    main()
diff --git a/OCR/line_OCR/ctc/main_syn_line.py b/OCR/line_OCR/ctc/main_syn_line.py
deleted file mode 100644
index c16f9de5ab1fb38ce18486146f997ba60baedc61..0000000000000000000000000000000000000000
--- a/OCR/line_OCR/ctc/main_syn_line.py
+++ /dev/null
@@ -1,148 +0,0 @@
-
-import os
-import sys
-from os.path import dirname
-DOSSIER_COURRANT = dirname(os.path.abspath(__file__))
-ROOT_FOLDER = dirname(dirname(dirname(DOSSIER_COURRANT)))
-sys.path.append(ROOT_FOLDER)
-from OCR.line_OCR.ctc.trainer_line_ctc import TrainerLineCTC
-from OCR.line_OCR.ctc.models_line_ctc import Decoder
-from dan.models import FCN_Encoder
-from torch.optim import Adam
-from basic.transforms import line_aug_config
-from basic.scheduler import exponential_dropout_scheduler, exponential_scheduler
-from OCR.ocr_dataset_manager import OCRDataset, OCRDatasetManager
-import torch.multiprocessing as mp
-import torch
-import numpy as np
-import random
-
-
-def train_and_test(rank, params):
-    torch.manual_seed(0)
-    torch.cuda.manual_seed(0)
-    np.random.seed(0)
-    random.seed(0)
-    torch.backends.cudnn.benchmark = False
-    torch.backends.cudnn.deterministic = True
-
-    params["training_params"]["ddp_rank"] = rank
-    model = TrainerLineCTC(params)
-
-    model.generate_syn_line_dataset("READ_2016_syn_line")  # ["RIMES_syn_line", "READ_2016_syn_line"]
-
-
-def main():
-    dataset_name = "READ_2016"  # ["RIMES", "READ_2016"]
-    dataset_level = "page"
-    params = {
-        "dataset_params": {
-            "dataset_manager": OCRDatasetManager,
-            "dataset_class": OCRDataset,
-            "datasets": {
-                dataset_name: "../../../Datasets/formatted/{}_{}".format(dataset_name, dataset_level),
-            },
-            "train": {
-                "name": "{}-train".format(dataset_name),
-                "datasets": [(dataset_name, "train"), ],
-            },
-            "valid": {
-                "{}-valid".format(dataset_name): [(dataset_name, "valid"), ],
-            },
-            "config": {
-                "load_in_memory": False,  # Load all images in CPU memory
-                "worker_per_gpu": 4,
-                "width_divisor": 8,  # Image width will be divided by 8
-                "height_divisor": 32,  # Image height will be divided by 32
-                "padding_value": 0,  # Image padding value
-                "padding_token": 1000,  # Label padding value (None: default value is chosen)
-                "padding_mode": "br",  # Padding at bottom and right
-                "charset_mode": "CTC",  # add blank token
-                "constraints": [],  # Padding for CTC requirements if necessary
-                "normalize": True,  # Normalize with mean and variance of training dataset
-                "preprocessings": [],
-                # Augmentation techniques to use at training time
-                "augmentation": line_aug_config(0.9, 0.1),
-                #
-                "synthetic_data": {
-                    "mode": "line_hw_to_printed",
-                    "init_proba": 1,
-                    "end_proba": 1,
-                    "num_steps_proba": 1e5,
-                    "proba_scheduler_function": exponential_scheduler,
-                    "config": {
-                        "background_color_default": (255, 255, 255),
-                        "background_color_eps": 15,
-                        "text_color_default": (0, 0, 0),
-                        "text_color_eps": 15,
-                        "font_size_min": 30,
-                        "font_size_max": 50,
-                        "color_mode": "RGB",
-                        "padding_left_ratio_min": 0.02,
-                        "padding_left_ratio_max": 0.1,
-                        "padding_right_ratio_min": 0.02,
-                        "padding_right_ratio_max": 0.1,
-                        "padding_top_ratio_min": 0.02,
-                        "padding_top_ratio_max": 0.2,
-                        "padding_bottom_ratio_min": 0.02,
-                        "padding_bottom_ratio_max": 0.2,
-                    },
-                },
-            }
-        },
-
-        "model_params": {
-            # Model classes to use for each module
-            "models": {
-                "encoder": FCN_Encoder,
-                "decoder": Decoder,
-            },
-            "transfer_learning": None,
-            "input_channels": 3,  # 1 for grayscale images, 3 for RGB ones (or grayscale as RGB)
-            "enc_size": 256,
-            "dropout_scheduler": {
-                "function": exponential_dropout_scheduler,
-                "T": 5e4,
-            },
-            "dropout": 0.5,
-        },
-
-        "training_params": {
-            "output_folder": "FCN_Encoder_read_syn_line_all_pad_max_cursive",  # folder names for logs and weigths
-            "max_nb_epochs": 10000,  # max number of epochs for the training
-            "max_training_time": 3600 * 24 * 1.9,  # max training time limit (in seconds)
-            "load_epoch": "last",  # ["best", "last"], to load weights from best epoch or last trained epoch
-            "interval_save_weights": None,  # None: keep best and last only
-            "use_ddp": False,  # Use DistributedDataParallel
-            "use_amp": True,  # Enable automatic mix-precision
-            "nb_gpu": torch.cuda.device_count(),
-            "batch_size": 1,  # mini-batch size per GPU
-            "optimizers": {
-                "all": {
-                    "class": Adam,
-                    "args": {
-                        "lr": 0.0001,
-                        "amsgrad": False,
-                    }
-                }
-            },
-            "lr_schedulers": None,
-            "eval_on_valid": True,  # Whether to eval and logs metrics on validation set during training or not
-            "eval_on_valid_interval": 2,  # Interval (in epochs) to evaluate during training
-            "focus_metric": "cer",  # Metrics to focus on to determine best epoch
-            "expected_metric_value": "low",  # ["high", "low"] What is best for the focus metric value
-            "set_name_focus_metric": "{}-valid".format(dataset_name),
-            "train_metrics": ["loss_ctc", "cer", "wer"],  # Metrics name for training
-            "eval_metrics": ["loss_ctc", "cer", "wer"],  # Metrics name for evaluation on validation set during training
-            "force_cpu": False,  # True for debug purposes to run on cpu only
-        },
-    }
-
-    if params["training_params"]["use_ddp"] and not params["training_params"]["force_cpu"]:
-        mp.spawn(train_and_test, args=(params,), nprocs=params["training_params"]["nb_gpu"])
-    else:
-        train_and_test(0, params)
-
-
-if __name__ == "__main__":
-    main()
\ No newline at end of file
diff --git a/OCR/line_OCR/ctc/models_line_ctc.py b/OCR/line_OCR/ctc/models_line_ctc.py
deleted file mode 100644
index c13c072ebd10e810d41b8aab1e318c2170637ea0..0000000000000000000000000000000000000000
--- a/OCR/line_OCR/ctc/models_line_ctc.py
+++ /dev/null
@@ -1,19 +0,0 @@
-
-from torch.nn.functional import log_softmax
-from torch.nn import AdaptiveMaxPool2d, Conv1d
-from torch.nn import Module
-
-
-class Decoder(Module):
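-    """
-    CTC prediction head: collapse the height axis with adaptive max pooling,
-    then map each horizontal position to charset logits plus the CTC blank.
-    """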
-    def __init__(self, params):
-        super(Decoder, self).__init__()
-
-        self.vocab_size = params["vocab_size"]
-
-        self.ada_pool = AdaptiveMaxPool2d((1, None))
-        self.end_conv = Conv1d(in_channels=params["enc_size"], out_channels=self.vocab_size+1, kernel_size=1)
-
-    def forward(self, x):
-        x = self.ada_pool(x).squeeze(2)
-        x = self.end_conv(x)
-        return log_softmax(x, dim=1)
diff --git a/OCR/line_OCR/ctc/trainer_line_ctc.py b/OCR/line_OCR/ctc/trainer_line_ctc.py
deleted file mode 100644
index eba12d81ecb4a417115121822d640be374cf185a..0000000000000000000000000000000000000000
--- a/OCR/line_OCR/ctc/trainer_line_ctc.py
+++ /dev/null
@@ -1,96 +0,0 @@
-
-from basic.metric_manager import MetricManager
-from OCR.ocr_manager import OCRManager
-from dan.ocr_utils import LM_ind_to_str
-import torch
-from torch.cuda.amp import autocast
-from torch.nn import CTCLoss
-import re
-import time
-
-
-class TrainerLineCTC(OCRManager):
-
-    def __init__(self, params):
-        super(TrainerLineCTC, self).__init__(params)
-
-    def train_batch(self, batch_data, metric_names):
-        """
-        Forward and backward pass for training
-        """
-        x = batch_data["imgs"].to(self.device)
-        y = batch_data["labels"].to(self.device)
-        x_reduced_len = [s[1] for s in batch_data["imgs_reduced_shape"]]
-        y_len = batch_data["labels_len"]
-
-        loss_ctc = CTCLoss(blank=self.dataset.tokens["blank"])
-        self.zero_optimizers()
-
-        with autocast(enabled=self.params["training_params"]["use_amp"]):
-            x = self.models["encoder"](x)
-            global_pred = self.models["decoder"](x)
-
-        loss = loss_ctc(global_pred.permute(2, 0, 1), y, x_reduced_len, y_len)
-
-        self.backward_loss(loss)
-
-        self.step_optimizers()
-        pred = torch.argmax(global_pred, dim=1).cpu().numpy()
-
-        values = {
-            "nb_samples": len(batch_data["raw_labels"]),
-            "loss_ctc": loss.item(),
-            "str_x": self.pred_to_str(pred, x_reduced_len),
-            "str_y": batch_data["raw_labels"]
-        }
-
-        return values
-
-    def evaluate_batch(self, batch_data, metric_names):
-        """
-        Forward pass only for validation and test
-        """
-        x = batch_data["imgs"].to(self.device)
-        y = batch_data["labels"].to(self.device)
-        x_reduced_len = [s[1] for s in batch_data["imgs_reduced_shape"]]
-        y_len = batch_data["labels_len"]
-
-        loss_ctc = CTCLoss(blank=self.dataset.tokens["blank"])
-
-        start_time = time.time()
-        with autocast(enabled=self.params["training_params"]["use_amp"]):
-            x = self.models["encoder"](x)
-            global_pred = self.models["decoder"](x)
-
-        loss = loss_ctc(global_pred.permute(2, 0, 1), y, x_reduced_len, y_len)
-        pred = torch.argmax(global_pred, dim=1).cpu().numpy()
-        str_x = self.pred_to_str(pred, x_reduced_len)
-
-        process_time = time.time() - start_time
-
-        values = {
-            "nb_samples": len(batch_data["raw_labels"]),
-            "loss_ctc": loss.item(),
-            "str_x": str_x,
-            "str_y": batch_data["raw_labels"],
-            "time": process_time
-        }
-        return values
-
-    def ctc_remove_successives_identical_ind(self, ind):
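-        # Collapse runs of identical consecutive indices (CTC best-path decoding).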
-        res = []
-        for i in ind:
-            if res and res[-1] == i:
-                continue
-            res.append(i)
-        return res
-
-    def pred_to_str(self, pred, pred_len):
-        """
-        convert prediction tokens to string
-        """
-        ind_x = [pred[i][:pred_len[i]] for i in range(pred.shape[0])]
-        ind_x = [self.ctc_remove_successives_identical_ind(t) for t in ind_x]
-        str_x = [LM_ind_to_str(self.dataset.charset, t, oov_symbol="") for t in ind_x]
-        str_x = [re.sub("( )+", ' ', t).strip(" ") for t in str_x]
-        return str_x
diff --git a/OCR/ocr_dataset_manager.py b/OCR/ocr_dataset_manager.py
deleted file mode 100644
index ed7c087cba8a3b5ef7d435a8c73eba01d35ff40f..0000000000000000000000000000000000000000
--- a/OCR/ocr_dataset_manager.py
+++ /dev/null
@@ -1,1036 +0,0 @@
-#  Copyright Université de Rouen Normandie (1), INSA Rouen (2),
-#  tutelles du laboratoire LITIS (1 et 2)
-#  contributors :
-#  - Denis Coquenet
-#
-#
-#  This software is a computer program written in Python  whose purpose is to
-#  provide public implementation of deep learning works, in pytorch.
-#
-#  This software is governed by the CeCILL-C license under French law and
-#  abiding by the rules of distribution of free software.  You can  use,
-#  modify and/ or redistribute the software under the terms of the CeCILL-C
-#  license as circulated by CEA, CNRS and INRIA at the following URL
-#  "http://www.cecill.info".
-#
-#  As a counterpart to the access to the source code and  rights to copy,
-#  modify and redistribute granted by the license, users are provided only
-#  with a limited warranty  and the software's author,  the holder of the
-#  economic rights,  and the successive licensors  have only  limited
-#  liability.
-#
-#  In this respect, the user's attention is drawn to the risks associated
-#  with loading,  using,  modifying and/or developing or reproducing the
-#  software by the user in light of its specific status of free software,
-#  that may mean  that it is complicated to manipulate,  and  that  also
-#  therefore means  that it is reserved for developers  and  experienced
-#  professionals having in-depth computer knowledge. Users are therefore
-#  encouraged to load and test the software's suitability as regards their
-#  requirements in conditions enabling the security of their systems and/or
-#  data to be ensured and,  more generally, to use and operate it in the
-#  same conditions as regards security.
-#
-#  The fact that you are presently reading this means that you have had
-#  knowledge of the CeCILL-C license and that you accept its terms.
-
-from basic.generic_dataset_manager import DatasetManager, GenericDataset
-from basic.utils import (pad_images, pad_image, pad_image_width_right, pad_image_width_left,
-                         pad_image_width_random, pad_image_height_random, pad_sequences_1D,
-                         resize_max, randint, rand, rand_uniform)
-from basic.generic_dataset_manager import apply_preprocessing
-from Datasets.dataset_formatters.read2016_formatter import SEM_MATCHING_TOKENS as READ_MATCHING_TOKENS
-from Datasets.dataset_formatters.rimes_formatter import order_text_regions as order_text_regions_rimes
-from Datasets.dataset_formatters.rimes_formatter import SEM_MATCHING_TOKENS as RIMES_MATCHING_TOKENS
-from Datasets.dataset_formatters.rimes_formatter import SEM_MATCHING_TOKENS_STR as RIMES_MATCHING_TOKENS_STR
-from dan.ocr_utils import LM_str_to_ind
-import random
-import cv2
-import os
-import copy
-import pickle
-import numpy as np
-import torch
-from PIL import Image, ImageDraw, ImageFont
-from basic.transforms import RandomRotation, apply_transform, Tightening
-from fontTools.ttLib import TTFont
-
-
-class OCRDatasetManager(DatasetManager):
-    """
-    Specific class to handle OCR/HTR tasks
-    """
-
-    def __init__(self, params):
-        super(OCRDatasetManager, self).__init__(params)
-
-        self.charset = params["charset"] if "charset" in params else self.get_merged_charsets()
-
-        if "synthetic_data" in self.params["config"] and self.params["config"]["synthetic_data"] and "config" in self.params["config"]["synthetic_data"]:
-            self.char_only_set = self.charset.copy()
-            for token_dict in [RIMES_MATCHING_TOKENS, READ_MATCHING_TOKENS]:
-                for key in token_dict:
-                    if key in self.char_only_set:
-                        self.char_only_set.remove(key)
-                    if token_dict[key] in self.char_only_set:
-                        self.char_only_set.remove(token_dict[key])
-            for token in ["\n", ]:
-                if token in self.char_only_set:
-                    self.char_only_set.remove(token)
-            self.params["config"]["synthetic_data"]["config"]["valid_fonts"] = get_valid_fonts(self.char_only_set)
-
-        if "new_tokens" in params:
-            self.charset = sorted(list(set(self.charset).union(set(params["new_tokens"]))))
-
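-        # Special token ids are appended after the charset:
-        # blank/pad for CTC, end/start/pad for seq2seq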
-        self.tokens = {
-            "pad": params["config"]["padding_token"],
-        }
-        if self.params["config"]["charset_mode"].lower() == "ctc":
-            self.tokens["blank"] = len(self.charset)
-            self.tokens["pad"] = self.tokens["pad"] if self.tokens["pad"] else len(self.charset) + 1
-            self.params["config"]["padding_token"] = self.tokens["pad"]
-        elif self.params["config"]["charset_mode"] == "seq2seq":
-            self.tokens["end"] = len(self.charset)
-            self.tokens["start"] = len(self.charset) + 1
-            self.tokens["pad"] = self.tokens["pad"] if self.tokens["pad"] else len(self.charset) + 2
-            self.params["config"]["padding_token"] = self.tokens["pad"]
-
-    def get_merged_charsets(self):
-        """
-        Merge the charset of the different datasets used
-        """
-        datasets = self.params["datasets"]
-        charset = set()
-        for key in datasets.keys():
-            with open(os.path.join(datasets[key], "labels.pkl"), "rb") as f:
-                info = pickle.load(f)
-                charset = charset.union(set(info["charset"]))
-        if "\n" in charset and "remove_linebreaks" in self.params["config"]["constraints"]:
-            charset.remove("\n")
-        if "" in charset:
-            charset.remove("")
-        return sorted(list(charset))
-
-    def apply_specific_treatment_after_dataset_loading(self, dataset):
-        dataset.charset = self.charset
-        dataset.tokens = self.tokens
-        dataset.convert_labels()
-        if "READ_2016" in dataset.name and "augmentation" in dataset.params["config"] and dataset.params["config"]["augmentation"]:
-            dataset.params["config"]["augmentation"]["fill_value"] = tuple([int(i) for i in dataset.mean])
-        if "padding" in dataset.params["config"] and dataset.params["config"]["padding"]["min_height"] == "max":
-            dataset.params["config"]["padding"]["min_height"] = max([s["img"].shape[0] for s in self.train_dataset.samples])
-        if "padding" in dataset.params["config"] and dataset.params["config"]["padding"]["min_width"] == "max":
-            dataset.params["config"]["padding"]["min_width"] = max([s["img"].shape[1] for s in self.train_dataset.samples])
-
-
-class OCRDataset(GenericDataset):
-    """
-    Specific class to handle OCR/HTR datasets
-    """
-
-    def __init__(self, params, set_name, custom_name, paths_and_sets):
-        super(OCRDataset, self).__init__(params, set_name, custom_name, paths_and_sets)
-        self.charset = None
-        self.tokens = None
-        self.reduce_dims_factor = np.array([params["config"]["height_divisor"], params["config"]["width_divisor"], 1])
-        self.collate_function = OCRCollateFunction
-        self.synthetic_id = 0
-
-    def __getitem__(self, idx):
-        sample = copy.deepcopy(self.samples[idx])
-
-        if not self.load_in_memory:
-            sample["img"] = self.get_sample_img(idx)
-            sample = apply_preprocessing(sample, self.params["config"]["preprocessings"])
-
-        if "synthetic_data" in self.params["config"] and self.params["config"]["synthetic_data"] and self.set_name == "train":
-            sample = self.generate_synthetic_data(sample)
-
-        # Data augmentation
-        sample["img"], sample["applied_da"] = self.apply_data_augmentation(sample["img"])
-
-        if "max_size" in self.params["config"] and self.params["config"]["max_size"]:
-            max_ratio = max(sample["img"].shape[0] / self.params["config"]["max_size"]["max_height"], sample["img"].shape[1] / self.params["config"]["max_size"]["max_width"])
-            if max_ratio > 1:
-                new_h, new_w = int(np.ceil(sample["img"].shape[0] / max_ratio)), int(np.ceil(sample["img"].shape[1] / max_ratio))
-                sample["img"] = cv2.resize(sample["img"], (new_w, new_h))
-
-        # Normalization if requested
-        if "normalize" in self.params["config"] and self.params["config"]["normalize"]:
-            sample["img"] = (sample["img"] - self.mean) / self.std
-
-        sample["img_shape"] = sample["img"].shape
-        sample["img_reduced_shape"] = np.ceil(sample["img_shape"] / self.reduce_dims_factor).astype(int)
-
-        # Padding to handle CTC requirements
-        if self.set_name == "train":
-            max_label_len = 0
-            height = 1
-            ctc_padding = False
-            if "CTC_line" in self.params["config"]["constraints"]:
-                max_label_len = sample["label_len"]
-                ctc_padding = True
-            if "CTC_va" in self.params["config"]["constraints"]:
-                max_label_len = max(sample["line_label_len"])
-                ctc_padding = True
-            if "CTC_pg" in self.params["config"]["constraints"]:
-                max_label_len = sample["label_len"]
-                height = max(sample["img_reduced_shape"][0], 1)
-                ctc_padding = True
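-            # CTC alignment needs enough output frames: 2 * label_len + 1 covers the
-            # worst case of a blank between every pair of characters plus one at each end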
-            if ctc_padding and 2 * max_label_len + 1 > sample["img_reduced_shape"][1]*height:
-                sample["img"] = pad_image_width_right(sample["img"], int(np.ceil((2 * max_label_len + 1) / height) * self.reduce_dims_factor[1]), self.padding_value)
-                sample["img_shape"] = sample["img"].shape
-                sample["img_reduced_shape"] = np.ceil(sample["img_shape"] / self.reduce_dims_factor).astype(int)
-            sample["img_reduced_shape"] = [max(1, t) for t in sample["img_reduced_shape"]]
-
-        sample["img_position"] = [[0, sample["img_shape"][0]], [0, sample["img_shape"][1]]]
-        # Padding constraints to handle model needs
-        if "padding" in self.params["config"] and self.params["config"]["padding"]:
-            if self.set_name == "train" or not self.params["config"]["padding"]["train_only"]:
-                min_pad = self.params["config"]["padding"]["min_pad"]
-                max_pad = self.params["config"]["padding"]["max_pad"]
-                pad_width = randint(min_pad, max_pad) if min_pad is not None and max_pad is not None else None
-                pad_height = randint(min_pad, max_pad) if min_pad is not None and max_pad is not None else None
-
-                sample["img"], sample["img_position"] = pad_image(sample["img"], padding_value=self.padding_value,
-                                          new_width=self.params["config"]["padding"]["min_width"],
-                                          new_height=self.params["config"]["padding"]["min_height"],
-                                          pad_width=pad_width,
-                                          pad_height=pad_height,
-                                          padding_mode=self.params["config"]["padding"]["mode"],
-                                          return_position=True)
-        sample["img_reduced_position"] = [np.ceil(p / factor).astype(int) for p, factor in zip(sample["img_position"], self.reduce_dims_factor[:2])]
-        return sample
-
-    def get_charset(self):
-        charset = set()
-        for i in range(len(self.samples)):
-            charset = charset.union(set(self.samples[i]["label"]))
-        return charset
-
-    def convert_labels(self):
-        """
-        Label str to token at character level
-        """
-        for i in range(len(self.samples)):
-            self.samples[i] = self.convert_sample_labels(self.samples[i])
-
-    def convert_sample_labels(self, sample):
-        label = sample["label"]
-        line_labels = label.split("\n")
-        if "remove_linebreaks" in self.params["config"]["constraints"]:
-            full_label = label.replace("\n", " ").replace("  ", " ")
-            word_labels = full_label.split(" ")
-        else:
-            full_label = label
-            word_labels = label.replace("\n", " ").replace("  ", " ").split(" ")
-
-        sample["label"] = full_label
-        sample["token_label"] = LM_str_to_ind(self.charset, full_label)
-        if "add_eot" in self.params["config"]["constraints"]:
-            sample["token_label"].append(self.tokens["end"])
-        sample["label_len"] = len(sample["token_label"])
-        if "add_sot" in self.params["config"]["constraints"]:
-            sample["token_label"].insert(0, self.tokens["start"])
-
-        sample["line_label"] = line_labels
-        sample["token_line_label"] = [LM_str_to_ind(self.charset, l) for l in line_labels]
-        sample["line_label_len"] = [len(l) for l in line_labels]
-        sample["nb_lines"] = len(line_labels)
-
-        sample["word_label"] = word_labels
-        sample["token_word_label"] = [LM_str_to_ind(self.charset, l) for l in word_labels]
-        sample["word_label_len"] = [len(l) for l in word_labels]
-        sample["nb_words"] = len(word_labels)
-        return sample
-
-    def generate_synthetic_data(self, sample):
-        config = self.params["config"]["synthetic_data"]
-
-        if not (config["init_proba"] == config["end_proba"] == 1):
-            nb_samples = self.training_info["step"] * self.params["batch_size"]
-            if config["start_scheduler_at_max_line"]:
-                max_step = config["num_steps_proba"]
-                current_step = max(0, min(nb_samples-config["curr_step"]*(config["max_nb_lines"]-config["min_nb_lines"]), max_step))
-                proba = config["init_proba"] if self.get_syn_max_lines() < config["max_nb_lines"] else \
-                    config["proba_scheduler_function"](config["init_proba"], config["end_proba"], current_step, max_step)
-            else:
-                proba = config["proba_scheduler_function"](config["init_proba"], config["end_proba"],
-                                                       min(nb_samples, config["num_steps_proba"]),
-                                                       config["num_steps_proba"])
-            if rand() > proba:
-                return sample
-
-        if "mode" in config and config["mode"] == "line_hw_to_printed":
-            sample["img"] = self.generate_typed_text_line_image(sample["label"])
-            return sample
-
-        return self.generate_synthetic_page_sample()
-
-    def get_syn_max_lines(self):
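-        """
-        Curriculum: the maximum number of synthetic lines grows with the number of
-        seen samples, between min_nb_lines and max_nb_lines
-        """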
-        config = self.params["config"]["synthetic_data"]
-        if config["curriculum"]:
-            nb_samples = self.training_info["step"]*self.params["batch_size"]
-            max_nb_lines = min(config["max_nb_lines"], (nb_samples-config["curr_start"]) // config["curr_step"]+1)
-            return max(config["min_nb_lines"], max_nb_lines)
-        return config["max_nb_lines"]
-
-    def generate_synthetic_page_sample(self):
-        config = self.params["config"]["synthetic_data"]
-        max_nb_lines_per_page = self.get_syn_max_lines()
-        crop = config["crop_curriculum"] and max_nb_lines_per_page < config["max_nb_lines"]
-        sample = {
-            "name": "synthetic_data_{}".format(self.synthetic_id),
-            "path": None
-        }
-        self.synthetic_id += 1
-        nb_pages = 2 if "double" in config["dataset_level"] else 1
-        background_sample = copy.deepcopy(self.samples[randint(0, len(self))])
-        pages = list()
-        backgrounds = list()
-
-        h, w, c = background_sample["img"].shape
-        page_width = w // 2 if nb_pages == 2 else w
-        for i in range(nb_pages):
-            nb_lines_per_page = randint(config["min_nb_lines"], max_nb_lines_per_page+1)
-            background = np.ones((h, page_width, c), dtype=background_sample["img"].dtype) * 255
-            if i == 0 and nb_pages == 2:
-                background[:, -2:, :] = 0
-            backgrounds.append(background)
-            if "READ_2016" in self.params["datasets"].keys():
-                side = background_sample["pages_label"][i]["side"]
-                coords = {
-                    "left": int(0.15 * page_width) if side == "left" else int(0.05 * page_width),
-                    "right": int(0.95 * page_width) if side == "left" else int(0.85 * page_width),
-                    "top": int(0.05 * h),
-                    "bottom": int(0.85 * h),
-                }
-                pages.append(self.generate_synthetic_read2016_page(background, coords, side=side, crop=crop,
-                                                               nb_lines=nb_lines_per_page))
-            elif "RIMES" in self.params["datasets"].keys():
-                pages.append(self.generate_synthetic_rimes_page(background, nb_lines=nb_lines_per_page, crop=crop))
-            else:
-                raise NotImplementedError
-
-        if nb_pages == 1:
-            sample["img"] = pages[0][0]
-            sample["label_raw"] = pages[0][1]["raw"]
-            sample["label_begin"] = pages[0][1]["begin"]
-            sample["label_sem"] = pages[0][1]["sem"]
-            sample["label"] = pages[0][1]
-            sample["nb_cols"] = pages[0][2]
-        else:
-            if pages[0][0].shape[0] != pages[1][0].shape[0]:
-                max_height = max(pages[0][0].shape[0], pages[1][0].shape[0])
-                backgrounds[0] = backgrounds[0][:max_height]
-                backgrounds[0][:pages[0][0].shape[0]] = pages[0][0]
-                backgrounds[1] = backgrounds[1][:max_height]
-                backgrounds[1][:pages[1][0].shape[0]] = pages[1][0]
-                pages[0][0] = backgrounds[0]
-                pages[1][0] = backgrounds[1]
-            sample["label_raw"] = pages[0][1]["raw"] + "\n" + pages[1][1]["raw"]
-            sample["label_begin"] = pages[0][1]["begin"] + pages[1][1]["begin"]
-            sample["label_sem"] = pages[0][1]["sem"] + pages[1][1]["sem"]
-            sample["img"] = np.concatenate([pages[0][0], pages[1][0]], axis=1)
-            sample["nb_cols"] = pages[0][2] + pages[1][2]
-        sample["label"] = sample["label_raw"]
-        if "â“‘" in self.charset:
-            sample["label"] = sample["label_begin"]
-        if "â’·" in self.charset:
-            sample["label"] = sample["label_sem"]
-        sample["unchanged_label"] = sample["label"]
-        sample = self.convert_sample_labels(sample)
-        return sample
-
-    def generate_synthetic_rimes_page(self, background, nb_lines=20, crop=False):
-        max_nb_lines = self.get_syn_max_lines()
-        def larger_lines(label):
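-            # Merge pairs of consecutive lines as long as the merged line stays under max_len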
-            lines = label.split("\n")
-            new_lines = list()
-            while len(lines) > 0:
-                if len(lines) == 1:
-                    new_lines.append(lines[0])
-                    del lines[0]
-                elif len(lines[0]) + len(lines[1]) < max_len:
-                    new_lines.append("{} {}".format(lines[0], lines[1]))
-                    del lines[1]
-                    del lines[0]
-                else:
-                    new_lines.append(lines[0])
-                    del lines[0]
-            return "\n".join(new_lines)
-        config = self.params["config"]["synthetic_data"]
-        max_len = 100
-        matching_tokens = RIMES_MATCHING_TOKENS
-        matching_tokens_str = RIMES_MATCHING_TOKENS_STR
-        h, w, c = background.shape
-        stats = self.stat_sem_rimes()
-        ordered_modes = ['Corps de texte', 'PS/PJ', 'Ouverture', 'Date, Lieu', 'Coordonnées Expéditeur', 'Coordonnées Destinataire', ]
-        object_ref = ['Objet', 'Reference']
-        random.shuffle(object_ref)
-        ordered_modes = ordered_modes[:3] + object_ref + ordered_modes[3:]
-        kept_modes = list()
-        for mode in ordered_modes:
-            if rand_uniform(0, 1) < stats[mode]:
-                kept_modes.append(mode)
-
-        paragraphs = dict()
-        for mode in kept_modes:
-            paragraphs[mode] = self.get_paragraph_rimes(mode=mode, mix=True)
-            # proba to merge multiple body paragraphs into one
-            if mode == "Corps de texte" and rand_uniform(0, 1) < 0.2:
-                nb_lines = min(nb_lines+10, max_nb_lines) if max_nb_lines < 30 else nb_lines+10
-                concat_line = randint(0, 2) == 0
-                if concat_line:
-                    paragraphs[mode]["label"] = larger_lines(paragraphs[mode]["label"])
-                while len(paragraphs[mode]["label"].split("\n")) <= 30:
-                    body2 = self.get_paragraph_rimes(mode=mode, mix=True)
-                    paragraphs[mode]["label"] += "\n" + larger_lines(body2["label"]) if concat_line else body2["label"]
-                    paragraphs[mode]["label"] = "\n".join(paragraphs[mode]["label"].split("\n")[:40])
-        # proba to set whole text region to uppercase
-        if rand_uniform(0, 1) < 0.1 and "Corps de texte" in paragraphs:
-            paragraphs["Corps de texte"]["label"] = paragraphs["Corps de texte"]["label"].upper().replace("È", "E").replace("Ë", "E").replace("Û", "U").replace("Ù", "U").replace("Î", "I").replace("Ï", "I").replace("Â", "A").replace("Œ", "OE")
-        # proba to duplicate a line and place it randomly elsewhere, in a body region
-        if rand_uniform(0, 1) < 0.1 and "Corps de texte" in paragraphs:
-            labels = paragraphs["Corps de texte"]["label"].split("\n")
-            duplicated_label = labels[randint(0, len(labels))]
-            labels.insert(randint(0, len(labels)), duplicated_label)
-            paragraphs["Corps de texte"]["label"] = "\n".join(labels)
-        # proba to merge successive lines to have longer text lines in body
-        if rand_uniform(0, 1) < 0.1 and "Corps de texte" in paragraphs:
-            paragraphs["Corps de texte"]["label"] = larger_lines(paragraphs["Corps de texte"]["label"])
-        for mode in paragraphs.keys():
-            line_labels = paragraphs[mode]["label"].split("\n")
-            if len(line_labels) == 0:
-                print("ERROR")
-            paragraphs[mode]["lines"] = list()
-            for line_label in line_labels:
-                if len(line_label) > max_len:
-                    for chunk in [line_label[i:i + max_len] for i in range(0, len(line_label), max_len)]:
-                        paragraphs[mode]["lines"].append(chunk)
-                else:
-                    paragraphs[mode]["lines"].append(line_label)
-        page_labels = {
-            "raw": "",
-            "begin": "",
-            "sem": ""
-        }
-        top_limit = 0
-        bottom_limit = h
-        max_bottom_crop = 0
-        min_top_crop = h
-        has_opening = has_object = has_reference = False
-        top_opening = top_object = top_reference = 0
-        right_opening = right_object = right_reference = 0
-        date_on_top = False
-        date_alone = False
-        for mode in kept_modes:
-            pg = paragraphs[mode]
-            if len(pg["lines"]) > nb_lines:
-                pg["lines"] = pg["lines"][:nb_lines]
-            nb_lines -= len(pg["lines"])
-            pg_image = self.generate_typed_text_paragraph_image(pg["lines"], padding_value=255, max_pad_left_ratio=1, same_font_size=True)
-            # proba to remove some interline spacing
-            if rand_uniform(0, 1) < 0.1:
-                pg_image = apply_transform(pg_image, Tightening(color=255, remove_proba=0.75))
-            # proba to rotate text region
-            if rand_uniform(0, 1) < 0.1:
-                pg_image = apply_transform(pg_image, RandomRotation(degrees=10, expand=True, fill=255))
-            pg["added"] = True
-            if mode == 'Corps de texte':
-                pg_image = resize_max(pg_image, max_height=int(0.5*h), max_width=w)
-                img_h, img_w = pg_image.shape[:2]
-                min_top = int(0.4*h)
-                max_top = int(0.9*h - img_h)
-                top = randint(min_top, max_top + 1)
-                left = randint(0, int(w - img_w) + 1)
-                bottom_body = top + img_h
-                top_body = top
-                bottom_limit = min(top, bottom_limit)
-            elif mode == "PS/PJ":
-                pg_image = resize_max(pg_image, max_height=int(0.03*h), max_width=int(0.9*w))
-                img_h, img_w = pg_image.shape[:2]
-                min_top = bottom_body
-                max_top = int(min(h - img_h, bottom_body + 0.15*h))
-                top = randint(min_top, max_top + 1)
-                left = randint(0, int(w - img_w) + 1)
-                bottom_limit = min(top, bottom_limit)
-            elif mode == "Ouverture":
-                pg_image = resize_max(pg_image, max_height=int(0.03 * h), max_width=int(0.9 * w))
-                img_h, img_w = pg_image.shape[:2]
-                min_top = int(top_body - 0.05 * h)
-                max_top = top_body - img_h
-                top = randint(min_top, max_top + 1)
-                left = randint(0, min(int(0.15*w), int(w - img_w)) + 1)
-                has_opening = True
-                top_opening = top
-                right_opening = left + img_w
-                bottom_limit = min(top, bottom_limit)
-            elif mode == "Objet":
-                pg_image = resize_max(pg_image, max_height=int(0.03 * h), max_width=int(0.9 * w))
-                img_h, img_w = pg_image.shape[:2]
-                max_top = top_reference - img_h if has_reference else top_opening - img_h if has_opening else top_body - img_h
-                min_top = int(max_top - 0.05 * h)
-                top = randint(min_top, max_top + 1)
-                left = randint(0, min(int(0.15*w), int(w - img_w)) + 1)
-                has_object = True
-                top_object = top
-                right_object = left + img_w
-                bottom_limit = min(top, bottom_limit)
-            elif mode == "Reference":
-                pg_image = resize_max(pg_image, max_height=int(0.03 * h), max_width=int(0.9 * w))
-                img_h, img_w = pg_image.shape[:2]
-                max_top = top_object - img_h if has_object else top_opening - img_h if has_opening else top_body - img_h
-                min_top = int(max_top - 0.05 * h)
-                top = randint(min_top, max_top + 1)
-                left = randint(0, min(int(0.15*w), int(w - img_w)) + 1)
-                has_reference = True
-                top_reference = top
-                right_reference = left + img_w
-                bottom_limit = min(top, bottom_limit)
-            elif mode == 'Date, Lieu':
-                pg_image = resize_max(pg_image, max_height=int(0.03 * h), max_width=int(0.45 * w))
-                img_h, img_w = pg_image.shape[:2]
-                if h - max_bottom_crop - 10 > img_h and randint(0, 10) == 0:
-                    top = randint(max_bottom_crop, h)
-                    left = randint(0, w-img_w)
-                else:
-                    min_top = top_body - img_h
-                    max_top = top_body - img_h
-                    min_left = 0
-                    # Check whether there is enough room to place the date to the right of the opening, reference or object
-                    if object_ref == ['Objet', 'Reference']:
-                        have = [has_opening, has_object, has_reference]
-                        rights = [right_opening, right_object, right_reference]
-                        tops = [top_opening, top_object, top_reference]
-                    else:
-                        have = [has_opening, has_reference, has_object]
-                        rights = [right_opening, right_reference, right_object]
-                        tops = [top_opening, top_reference, top_object]
-                    for right_r, top_r, has_r in zip(rights, tops, have):
-                        if has_r:
-                            if right_r + img_w >= 0.95*w:
-                                max_top = min(top_r - img_h, max_top)
-                                min_left = 0
-                            else:
-                                min_left = max(min_left, right_r+0.05*w)
-                                min_top = top_r - img_h if min_top == top_body - img_h else min_top
-                    if min_left != 0 and randint(0, 5) == 0:
-                        min_left = 0
-                        for right_r, top_r, has_r in zip(rights, tops, have):
-                            if has_r:
-                                max_top = min(max_top, top_r-img_h)
-
-                    max_left = max(min_left, w - img_w)
-
-                    # No placement found to the right of the opening, reference or object
-                    if min_left == 0:
-                        # place on the top
-                        if randint(0, 2) == 0:
-                            min_top = 0
-                            max_top = int(min(0.05*h, max_top))
-                            date_on_top = True
-                        # place just before object/reference/opening
-                        else:
-                            min_top = int(max(0, max_top - 0.05*h))
-                            date_alone = True
-                            max_left = min(max_left, int(0.1*w))
-
-                    min_top = min(min_top, max_top)
-                    top = randint(min_top, max_top + 1)
-                    left = randint(int(min_left), max_left + 1)
-                    if date_on_top:
-                        top_limit = max(top_limit, top + img_h)
-                    else:
-                        bottom_limit = min(top, bottom_limit)
-                    date_right = left + img_w
-                    date_bottom = top + img_h
-            elif mode == "Coordonnées Expéditeur":
-                max_height = min(0.25*h, bottom_limit-top_limit)
-                if max_height <= 0:
-                    pg["added"] = False
-                    print("ko", bottom_limit, top_limit)
-                    break
-                pg_image = resize_max(pg_image, max_height=int(max_height), max_width=int(0.45 * w))
-                img_h, img_w = pg_image.shape[:2]
-                top = randint(top_limit, bottom_limit-img_h+1)
-                left = randint(0, int(0.5*w-img_w)+1)
-            elif mode == "Coordonnées Destinataire":
-                if h - max_bottom_crop - 10 > 0.2*h and randint(0, 10) == 0:
-                    pg_image = resize_max(pg_image, max_height=int(0.2*h), max_width=int(0.45 * w))
-                    img_h, img_w = pg_image.shape[:2]
-                    top = randint(max_bottom_crop, h)
-                    left = randint(0, w-img_w)
-                else:
-                    max_height = min(0.25*h, bottom_limit-top_limit)
-                    if max_height <= 0:
-                        pg["added"] = False
-                        print("ko", bottom_limit, top_limit)
-                        break
-                    pg_image = resize_max(pg_image, max_height=int(max_height), max_width=int(0.45 * w))
-                    img_h, img_w = pg_image.shape[:2]
-                    if date_alone and w - date_right - img_w > 11:
-                        top = randint(0, date_bottom-img_h+1)
-                        left = randint(max(int(0.5*w), date_right+10), w-img_w)
-                    else:
-                        top = randint(top_limit, bottom_limit-img_h+1)
-                        left = randint(int(0.5*w), int(w - img_w)+1)
-
-            bottom = top+img_h
-            right = left+img_w
-            min_top_crop = min(top, min_top_crop)
-            max_bottom_crop = max(bottom, max_bottom_crop)
-            try:
-                background[top:bottom, left:right, ...] = pg_image
-            except ValueError:  # paragraph image does not fit at the sampled position
-                pg["added"] = False
-                nb_lines = 0
-            pg["coords"] = {
-                "top": top,
-                "bottom": bottom,
-                "right": right,
-                "left": left
-            }
-
-            if nb_lines <= 0:
-                break
-        sorted_pg = order_text_regions_rimes(paragraphs.values())
-        for pg in sorted_pg:
-            if "added" in pg.keys() and pg["added"]:
-                pg_label = "\n".join(pg["lines"])
-                mode = pg["type"]
-                begin_token = matching_tokens_str[mode]
-                end_token = matching_tokens[begin_token]
-                page_labels["raw"] += pg_label
-                page_labels["begin"] += begin_token + pg_label
-                page_labels["sem"] += begin_token + pg_label + end_token
-        if crop:
-            if min_top_crop > max_bottom_crop:
-                print("KO - min > MAX")
-            elif min_top_crop > h:
-                print("KO - min > h")
-            else:
-                background = background[min_top_crop:max_bottom_crop]
-        return [background, page_labels, 1]
-
-    def stat_sem_rimes(self):
-        try:
-            return self.rimes_sem_stats
-        except AttributeError:
-            stats = dict()
-            for sample in self.samples:
-                for pg in sample["paragraphs_label"]:
-                    mode = pg["type"]
-                    if mode == 'Coordonnées Expéditeur':
-                        if len(pg["label"]) < 50 and "\n" not in pg["label"]:
-                            mode = "Reference"
-                    stats[mode] = stats.get(mode, 0) + 1
-            for key in stats:
-                stats[key] = max(0.10, stats[key]/len(self.samples))
-            self.rimes_sem_stats = stats
-            return stats
-
-    def get_paragraph_rimes(self, mode="Corps de texte", mix=False):
-        while True:
-            sample = self.samples[randint(0, len(self))]
-            random.shuffle(sample["paragraphs_label"])
-            for pg in sample["paragraphs_label"]:
-                pg_mode = pg["type"]
-                if pg_mode == 'Coordonnées Expéditeur':
-                    if len(pg["label"]) < 50 and "\n" not in pg["label"]:
-                        pg_mode = "Reference"
-                if mode == pg_mode:
-                    if mode == "Corps de texte" and mix:
-                        return self.get_mix_paragraph_rimes(mode, min(5, len(pg["label"].split("\n"))))
-                    else:
-                        return pg
-
-    def get_mix_paragraph_rimes(self, mode="Corps de texte", num_lines=10):
-        res = list()
-        while len(res) != num_lines:
-            sample = self.samples[randint(0, len(self))]
-            random.shuffle(sample["paragraphs_label"])
-            for pg in sample["paragraphs_label"]:
-                pg_mode = pg["type"]
-                if pg_mode == 'Coordonnées Expéditeur':
-                    if len(pg["label"]) < 50 and "\n" not in pg["label"]:
-                        pg_mode = "Reference"
-                if mode == pg_mode:
-                    lines = pg["label"].split("\n")
-                    res.append(lines[randint(0, len(lines))])
-                    break
-        return {
-            "label": "\n".join(res),
-            "type": mode,
-        }
-
-    def generate_synthetic_read2016_page(self, background, coords, side="left", nb_lines=20, crop=False):
-        config = self.params["config"]["synthetic_data"]
-        two_column = False
-        matching_token = READ_MATCHING_TOKENS
-        page_labels = {
-            "raw": "",
-            "begin": "â“Ÿ",
-            "sem": "â“Ÿ",
-        }
-        area_top = coords["top"]
-        area_left = coords["left"]
-        area_right = coords["right"]
-        area_bottom = coords["bottom"]
-        num_page_text_label = str(randint(0, 1000))
-        num_page_img = self.generate_typed_text_line_image(num_page_text_label)
-
-        if side == "left":
-            background[area_top:area_top+num_page_img.shape[0], area_left:area_left+num_page_img.shape[1]] = num_page_img
-        else:
-            background[area_top:area_top + num_page_img.shape[0], area_right-num_page_img.shape[1]:area_right] = num_page_img
-        for key in ["sem", "begin"]:
-            page_labels[key] += "ⓝ"
-        for key in page_labels.keys():
-            page_labels[key] += num_page_text_label
-        page_labels["sem"] += matching_token["ⓝ"]
-        nb_lines -= 1
-        area_top = area_top + num_page_img.shape[0] + randint(1, 20)
-        ratio_ann = rand_uniform(0.6, 0.7)
-        while nb_lines > 0:
-            nb_body_lines = randint(1, nb_lines+1)
-            max_ann_lines = min(nb_body_lines, nb_lines-nb_body_lines)
-            body_labels = list()
-            body_imgs = list()
-            while nb_body_lines > 0:
-                current_nb_lines = 1
-                label, img = self.get_printed_line_read_2016("body")
-
-                nb_body_lines -= current_nb_lines
-                body_labels.append(label)
-                body_imgs.append(img)
-            nb_ann_lines = randint(0, min(6, max_ann_lines+1))
-            ann_labels = list()
-            ann_imgs = list()
-            while nb_ann_lines > 0:
-                current_nb_lines = 1
-                label, img = self.get_printed_line_read_2016("annotation")
-                nb_ann_lines -= current_nb_lines
-                ann_labels.append(label)
-                ann_imgs.append(img)
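-            # Split the usable width between the body column and the annotation column
-            # according to ratio_ann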
-            max_width_body = int(np.floor(ratio_ann*(area_right-area_left)))
-            max_width_ann = area_right-area_left-max_width_body
-            for img_list, max_width in zip([body_imgs, ann_imgs], [max_width_body, max_width_ann]):
-                for i in range(len(img_list)):
-                    if img_list[i].shape[1] > max_width:
-                        ratio = max_width/img_list[i].shape[1]
-                        new_h = int(np.floor(ratio*img_list[i].shape[0]))
-                        new_w = int(np.floor(ratio*img_list[i].shape[1]))
-                        img_list[i] = cv2.resize(img_list[i], (new_w, new_h), interpolation=cv2.INTER_LINEAR)
-            body_top = area_top
-            body_height = 0
-            i_body = 0
-            for (label, img) in zip(body_labels, body_imgs):
-                remaining_height = area_bottom - body_top
-                if img.shape[0] > remaining_height:
-                    nb_lines = 0
-                    break
-                background[body_top:body_top+img.shape[0], area_left+max_width_ann:area_left+max_width_ann+img.shape[1]] = img
-                body_height += img.shape[0]
-                body_top += img.shape[0]
-                nb_lines -= 1
-                i_body += 1
-
-            ann_height = int(np.sum([img.shape[0] for img in ann_imgs]))
-            ann_top = area_top + randint(0, body_height-ann_height+1) if ann_height < body_height else area_top
-            largest_ann = max([a.shape[1] for a in ann_imgs]) if len(ann_imgs) > 0 else max_width_ann
-            pad_ann = randint(0, max_width_ann-largest_ann+1) if max_width_ann > largest_ann else 0
-
-            ann_label_blocks = [list(), ]
-            i_ann = 0
-            ann_height = 0
-            for (label, img) in zip(ann_labels, ann_imgs):
-                remaining_height = body_top - ann_top
-                if img.shape[0] > remaining_height:
-                    break
-                background[ann_top:ann_top+img.shape[0], area_left+pad_ann:area_left+pad_ann+img.shape[1]] = img
-                ann_height += img.shape[0]
-                ann_top += img.shape[0]
-                nb_lines -= 1
-                two_column = True
-                ann_label_blocks[-1].append(ann_labels[i_ann])
-                i_ann += 1
-                if randint(0, 10) == 0:
-                    ann_label_blocks.append(list())
-                    ann_top += randint(0, max(15, body_top-ann_top-20))
-
-            area_top = area_top + max(ann_height, body_height) + randint(25, 100)
-
-            ann_full_labels = {
-                "raw": "",
-                "begin": "",
-                "sem": "",
-            }
-            for ann_label_block in ann_label_blocks:
-                if len(ann_label_block) > 0:
-                    for key in ["sem", "begin"]:
-                        ann_full_labels[key] += "ⓐ"
-                    ann_full_labels["raw"] += "\n"
-                    for key in ann_full_labels.keys():
-                        ann_full_labels[key] += "\n".join(ann_label_block)
-                    ann_full_labels["sem"] += matching_token["ⓐ"]
-
-            body_full_labels = {
-                "raw": "",
-                "begin": "",
-                "sem": "",
-            }
-            if i_body > 0:
-                for key in ["sem", "begin"]:
-                    body_full_labels[key] += "ⓑ"
-                body_full_labels["raw"] += "\n"
-                for key in body_full_labels.keys():
-                    body_full_labels[key] += "\n".join(body_labels[:i_body])
-                body_full_labels["sem"] += matching_token["ⓑ"]
-
-            section_labels = dict()
-            for key in ann_full_labels.keys():
-                section_labels[key] = ann_full_labels[key] + body_full_labels[key]
-            for key in section_labels.keys():
-                if section_labels[key] != "":
-                    if key in ["sem", "begin"]:
-                        section_labels[key] = "â“¢" + section_labels[key]
-                    if key == "sem":
-                        section_labels[key] = section_labels[key] + matching_token["â“¢"]
-            for key in page_labels.keys():
-                page_labels[key] += section_labels[key]
-
-        if crop:
-            background = background[:area_top]
-
-        page_labels["sem"] += matching_token["â“Ÿ"]
-
-        for key in page_labels.keys():
-            page_labels[key] = page_labels[key].strip()
-
-        return [background, page_labels, 2 if two_column else 1]
-
-    def get_n_consecutive_lines_read_2016(self, n=1, mode="body"):
-        while True:
-            sample = self.samples[randint(0, len(self))]
-            paragraphs = list()
-            for page in sample["pages_label"]:
-                paragraphs.extend(page["paragraphs"])
-                random.shuffle(paragraphs)
-                for pg in paragraphs:
-                    if ((mode == "body" and pg["mode"] == "body") or
-                        (mode == "ann" and pg["mode"] == "annotation")) and len(pg["lines"]) >= n:
-                        line_idx = randint(0, len(pg["lines"])-n+1)
-                        lines = pg["lines"][line_idx:line_idx+n]
-                        label = "\n".join([l["text"] for l in lines])
-                        top = min([l["top"] for l in lines])
-                        bottom = max([l["bottom"] for l in lines])
-                        left = min([l["left"] for l in lines])
-                        right = max([l["right"] for l in lines])
-                        img = sample["img"][top:bottom, left:right]
-                        return label, img
-
-    def get_printed_line_read_2016(self, mode="body"):
-        while True:
-            sample = self.samples[randint(0, len(self))]
-            for page in sample["pages_label"]:
-                paragraphs = list()
-                paragraphs.extend(page["paragraphs"])
-                random.shuffle(paragraphs)
-                for pg in paragraphs:
-                    random.shuffle(pg["lines"])
-                    for line in pg["lines"]:
-                        if (mode == "body" and len(line["text"]) > 5) or (mode == "annotation" and len(line["text"]) < 15 and not line["text"].isdigit()):
-                            label = line["text"]
-                            img = self.generate_typed_text_line_image(label)
-                            return label, img
-
-    def generate_typed_text_line_image(self, text):
-        return generate_typed_text_line_image(text, self.params["config"]["synthetic_data"]["config"])
-
-    def generate_typed_text_paragraph_image(self, texts, padding_value=255, max_pad_left_ratio=0.1, same_font_size=False):
-        config = self.params["config"]["synthetic_data"]["config"]
-        if same_font_size:
-            images = list()
-            txt_color = config["text_color_default"]
-            bg_color = config["background_color_default"]
-            font_size = randint(config["font_size_min"], config["font_size_max"] + 1)
-            for text in texts:
-                font_path = config["valid_fonts"][randint(0, len(config["valid_fonts"]))]
-                fnt = ImageFont.truetype(font_path, font_size)
-                text_width, text_height = fnt.getsize(text)
-                padding_top = int(rand_uniform(config["padding_top_ratio_min"], config["padding_top_ratio_max"]) * text_height)
-                padding_bottom = int(rand_uniform(config["padding_bottom_ratio_min"], config["padding_bottom_ratio_max"]) * text_height)
-                padding_left = int(rand_uniform(config["padding_left_ratio_min"], config["padding_left_ratio_max"]) * text_width)
-                padding_right = int(rand_uniform(config["padding_right_ratio_min"], config["padding_right_ratio_max"]) * text_width)
-                padding = [padding_top, padding_bottom, padding_left, padding_right]
-                images.append(generate_typed_text_line_image_from_params(text, fnt, bg_color, txt_color, config["color_mode"], padding))
-        else:
-            images = [self.generate_typed_text_line_image(t) for t in texts]
-
-        max_width = max([img.shape[1] for img in images])
-
-        padded_images = [pad_image_width_random(img, max_width, padding_value=padding_value, max_pad_left_ratio=max_pad_left_ratio) for img in images]
-        return np.concatenate(padded_images, axis=0)
-
-
-class OCRCollateFunction:
-    """
-    Merge samples data to mini-batch data for OCR task
-    """
-
-    def __init__(self, config):
-        self.img_padding_value = float(config["padding_value"])
-        self.label_padding_value = config["padding_token"]
-        self.config = config
-
-    def __call__(self, batch_data):
-        names = [batch_data[i]["name"] for i in range(len(batch_data))]
-        ids = [batch_data[i]["name"].split("/")[-1].split(".")[0] for i in range(len(batch_data))]
-        applied_da = [batch_data[i]["applied_da"] for i in range(len(batch_data))]
-
-        labels = [batch_data[i]["token_label"] for i in range(len(batch_data))]
-        labels = pad_sequences_1D(labels, padding_value=self.label_padding_value)
-        labels = torch.tensor(labels).long()
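-        # The [-2:0:-1] slice reverses the inner tokens while keeping the first (start)
-        # and last (end) tokens in place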
-        reverse_labels = [[batch_data[i]["token_label"][0], ] + batch_data[i]["token_label"][-2:0:-1] + [batch_data[i]["token_label"][-1], ] for i in range(len(batch_data))]
-        reverse_labels = pad_sequences_1D(reverse_labels, padding_value=self.label_padding_value)
-        reverse_labels = torch.tensor(reverse_labels).long()
-        labels_len = [batch_data[i]["label_len"] for i in range(len(batch_data))]
-
-        raw_labels = [batch_data[i]["label"] for i in range(len(batch_data))]
-        unchanged_labels = [batch_data[i]["unchanged_label"] for i in range(len(batch_data))]
-
-        nb_cols = [batch_data[i]["nb_cols"] for i in range(len(batch_data))]
-        nb_lines = [batch_data[i]["nb_lines"] for i in range(len(batch_data))]
-        line_raw = [batch_data[i]["line_label"] for i in range(len(batch_data))]
-        line_token = [batch_data[i]["token_line_label"] for i in range(len(batch_data))]
-        pad_line_token = list()
-        line_len = [batch_data[i]["line_label_len"] for i in range(len(batch_data))]
-        for i in range(max(nb_lines)):
-            current_lines = [line_token[j][i] if i < nb_lines[j] else [self.label_padding_value] for j in range(len(batch_data))]
-            pad_line_token.append(torch.tensor(pad_sequences_1D(current_lines, padding_value=self.label_padding_value)).long())
-            for j in range(len(batch_data)):
-                if i >= nb_lines[j]:
-                    line_len[j].append(0)
-        line_len = list(zip(*line_len))
-
-        nb_words = [batch_data[i]["nb_words"] for i in range(len(batch_data))]
-        word_raw = [batch_data[i]["word_label"] for i in range(len(batch_data))]
-        word_token = [batch_data[i]["token_word_label"] for i in range(len(batch_data))]
-        pad_word_token = list()
-        word_len = [batch_data[i]["word_label_len"] for i in range(len(batch_data))]
-        for i in range(max(nb_words)):
-            current_words = [word_token[j][i] if i < nb_words[j] else [self.label_padding_value] for j in range(len(batch_data))]
-            pad_word_token.append(torch.tensor(pad_sequences_1D(current_words, padding_value=self.label_padding_value)).long())
-            for j in range(len(batch_data)):
-                if i >= nb_words[j]:
-                    word_len[j].append(0)
-        word_len = list(zip(*word_len))
-
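-        # Pad the images of the batch to a common size, then convert from NHWC to NCHW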
-        padding_mode = self.config["padding_mode"] if "padding_mode" in self.config else "br"
-        imgs = [batch_data[i]["img"] for i in range(len(batch_data))]
-        imgs_shape = [batch_data[i]["img_shape"] for i in range(len(batch_data))]
-        imgs_reduced_shape = [batch_data[i]["img_reduced_shape"] for i in range(len(batch_data))]
-        imgs_position = [batch_data[i]["img_position"] for i in range(len(batch_data))]
-        imgs_reduced_position = [batch_data[i]["img_reduced_position"] for i in range(len(batch_data))]
-        imgs = pad_images(imgs, padding_value=self.img_padding_value, padding_mode=padding_mode)
-        imgs = torch.tensor(imgs).float().permute(0, 3, 1, 2)
-        formatted_batch_data = {
-            "names": names,
-            "ids": ids,
-            "nb_lines": nb_lines,
-            "nb_cols": nb_cols,
-            "labels": labels,
-            "reverse_labels": reverse_labels,
-            "raw_labels": raw_labels,
-            "unchanged_labels": unchanged_labels,
-            "labels_len": labels_len,
-            "imgs": imgs,
-            "imgs_shape": imgs_shape,
-            "imgs_reduced_shape": imgs_reduced_shape,
-            "imgs_position": imgs_position,
-            "imgs_reduced_position": imgs_reduced_position,
-            "line_raw": line_raw,
-            "line_labels": pad_line_token,
-            "line_labels_len": line_len,
-            "nb_words": nb_words,
-            "word_raw": word_raw,
-            "word_labels": pad_word_token,
-            "word_labels_len": word_len,
-            "applied_da": applied_da
-        }
-
-        return formatted_batch_data
-
-
-def generate_typed_text_line_image(text, config, bg_color=(255, 255, 255), txt_color=(0, 0, 0)):
-    if text == "":
-        text = " "
-    if "text_color_default" in config:
-        txt_color = config["text_color_default"]
-    if "background_color_default" in config:
-        bg_color = config["background_color_default"]
-
-    font_path = config["valid_fonts"][randint(0, len(config["valid_fonts"]))]
-    font_size = randint(config["font_size_min"], config["font_size_max"]+1)
-    fnt = ImageFont.truetype(font_path, font_size)
-
-    text_width, text_height = fnt.getsize(text)
-    padding_top = int(rand_uniform(config["padding_top_ratio_min"], config["padding_top_ratio_max"])*text_height)
-    padding_bottom = int(rand_uniform(config["padding_bottom_ratio_min"], config["padding_bottom_ratio_max"])*text_height)
-    padding_left = int(rand_uniform(config["padding_left_ratio_min"], config["padding_left_ratio_max"])*text_width)
-    padding_right = int(rand_uniform(config["padding_right_ratio_min"], config["padding_right_ratio_max"])*text_width)
-    padding = [padding_top, padding_bottom, padding_left, padding_right]
-    return generate_typed_text_line_image_from_params(text, fnt, bg_color, txt_color, config["color_mode"], padding)
-
-
-def generate_typed_text_line_image_from_params(text, font, bg_color, txt_color, color_mode, padding):
-    padding_top, padding_bottom, padding_left, padding_right = padding
-    text_width, text_height = font.getsize(text)
-    img_height = padding_top + padding_bottom + text_height
-    img_width = padding_left + padding_right + text_width
-    img = Image.new(color_mode, (img_width, img_height), color=bg_color)
-    d = ImageDraw.Draw(img)
-    d.text((padding_left, padding_bottom), text, font=font, fill=txt_color, spacing=0)
-    return np.array(img)
-
-
-def get_valid_fonts(alphabet=None):
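-    """
-    Walk the font directory and keep the .ttf fonts that contain a glyph
-    for every character of the given alphabet
-    """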
-    valid_fonts = list()
-    for fold_detail in os.walk("../../../Fonts"):
-        if fold_detail[2]:
-            for font_name in fold_detail[2]:
-                if ".ttf" not in font_name:
-                    continue
-                font_path = os.path.join(fold_detail[0], font_name)
-                to_add = True
-                if alphabet is not None:
-                    for char in alphabet:
-                        if not char_in_font(char, font_path):
-                            to_add = False
-                            break
-                    if to_add:
-                        valid_fonts.append(font_path)
-                else:
-                    valid_fonts.append(font_path)
-    return valid_fonts
-
-
-def char_in_font(unicode_char, font_path):
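-    """
-    Return True if any unicode cmap table of the font maps the given character
-    """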
-    with TTFont(font_path) as font:
-        for cmap in font['cmap'].tables:
-            if cmap.isUnicode():
-                if ord(unicode_char) in cmap.cmap:
-                    return True
-    return False
diff --git a/OCR/ocr_manager.py b/OCR/ocr_manager.py
deleted file mode 100644
index e34b36c0a0793868c5a8f590b20510db72bc0681..0000000000000000000000000000000000000000
--- a/OCR/ocr_manager.py
+++ /dev/null
@@ -1,67 +0,0 @@
-from basic.generic_training_manager import GenericTrainingManager
-import os
-from PIL import Image
-import pickle
-
-
-class OCRManager(GenericTrainingManager):
-    def __init__(self, params):
-        super(OCRManager, self).__init__(params)
-        self.params["model_params"]["vocab_size"] = len(self.dataset.charset)
-
-    def generate_syn_line_dataset(self, name):
-        """
-        Generate synthetic line dataset from currently loaded dataset
-        """
-        dataset_name = list(self.params['dataset_params']["datasets"].keys())[0]
-        path = os.path.join(os.path.dirname(self.params['dataset_params']["datasets"][dataset_name]), name)
-        os.makedirs(path, exist_ok=True)
-        charset = set()
-        dataset = None
-        gt = {
-            "train": dict(),
-            "valid": dict(),
-            "test": dict()
-        }
-        for set_name in ["train", "valid", "test"]:
-            set_path = os.path.join(path, set_name)
-            os.makedirs(set_path, exist_ok=True)
-            if set_name == "train":
-                dataset = self.dataset.train_dataset
-            elif set_name == "valid":
-                dataset = self.dataset.valid_datasets["{}-valid".format(dataset_name)]
-            elif set_name == "test":
-                self.dataset.generate_test_loader("{}-test".format(dataset_name), [(dataset_name, "test"), ])
-                dataset = self.dataset.test_datasets["{}-test".format(dataset_name)]
-
-            samples = list()
-            for sample in dataset.samples:
-                for line_label in sample["label"].split("\n"):
-                    for chunk in [line_label[i:i+100] for i in range(0, len(line_label), 100)]:
-                        charset = charset.union(set(chunk))
-                        if len(chunk) > 0:
-                            samples.append({
-                                "path": sample["path"],
-                                "label": chunk,
-                                "nb_cols": 1,
-                            })
-
-            for i, sample in enumerate(samples):
-                ext = sample['path'].split(".")[-1]
-                img_name = "{}_{}.{}".format(set_name, i, ext)
-                img_path = os.path.join(set_path, img_name)
-
-                img = dataset.generate_typed_text_line_image(sample["label"])
-                Image.fromarray(img).save(img_path)
-                gt[set_name][img_name] = {
-                    "text": sample["label"],
-                    "nb_cols": sample["nb_cols"] if "nb_cols" in sample else 1
-                }
-                if "line_label" in sample:
-                    gt[set_name][img_name]["lines"] = sample["line_label"]
-
-        with open(os.path.join(path, "labels.pkl"), "wb") as f:
-            pickle.dump({
-                "ground_truth": gt,
-                "charset": sorted(list(charset)),
-            }, f)
\ No newline at end of file
diff --git a/README.md b/README.md
index 28d0dc9201d28151f4b494df0fc6685b991d0479..9b29f2040998f2a13c945eb8686490eab4d37069 100644
--- a/README.md
+++ b/README.md
@@ -2,10 +2,10 @@
 
 This repository is a public implementation of the paper: "DAN: a Segmentation-free Document Attention Network for Handwritten Document Recognition".
 
-![Prediction visualization](visual.png)
+![Prediction visualization](images/visual.png)
 
 The model uses a character-level attention to handle slanted lines:
-![Prediction visualization on slanted lines](visual_slanted_lines.png)
+![Prediction visualization on slanted lines](images/visual_slanted_lines.png)
 
 The paper is available at https://arxiv.org/abs/2203.12273.
 
diff --git a/basic/generic_dataset_manager.py b/basic/generic_dataset_manager.py
deleted file mode 100644
index dd272d152884b4fc2d27f6467ad915f099f77acc..0000000000000000000000000000000000000000
--- a/basic/generic_dataset_manager.py
+++ /dev/null
@@ -1,390 +0,0 @@
-import torch
-import random
-from torch.utils.data import Dataset, DataLoader
-from torch.utils.data.distributed import DistributedSampler
-from basic.transforms import apply_data_augmentation
-from Datasets.dataset_formatters.utils_dataset import natural_sort
-import os
-import numpy as np
-import pickle
-from PIL import Image
-import cv2
-
-
-class DatasetManager:
-
-    def __init__(self, params):
-        self.params = params
-        self.dataset_class = params["dataset_class"]
-        self.img_padding_value = params["config"]["padding_value"]
-
-        self.my_collate_function = None
-
-        self.train_dataset = None
-        self.valid_datasets = dict()
-        self.test_datasets = dict()
-
-        self.train_loader = None
-        self.valid_loaders = dict()
-        self.test_loaders = dict()
-
-        self.train_sampler = None
-        self.valid_samplers = dict()
-        self.test_samplers = dict()
-
-        self.generator = torch.Generator()
-        self.generator.manual_seed(0)
-
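-        # Per-split batch sizes: valid falls back to the train batch size, test defaults to 1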
-        self.batch_size = {
-            "train": self.params["batch_size"],
-            "valid": self.params["valid_batch_size"] if "valid_batch_size" in self.params else self.params["batch_size"],
-            "test": self.params["test_batch_size"] if "test_batch_size" in self.params else 1,
-        }
-
-    def apply_specific_treatment_after_dataset_loading(self, dataset):
-        raise NotImplementedError
-
-    def load_datasets(self):
-        """
-        Load training and validation datasets
-        """
-        self.train_dataset = self.dataset_class(self.params, "train", self.params["train"]["name"], self.get_paths_and_sets(self.params["train"]["datasets"]))
-        self.params["config"]["mean"], self.params["config"]["std"] = self.train_dataset.compute_std_mean()
-
-        self.my_collate_function = self.train_dataset.collate_function(self.params["config"])
-        self.apply_specific_treatment_after_dataset_loading(self.train_dataset)
-
-        for custom_name in self.params["valid"].keys():
-            self.valid_datasets[custom_name] = self.dataset_class(self.params, "valid", custom_name, self.get_paths_and_sets(self.params["valid"][custom_name]))
-            self.apply_specific_treatment_after_dataset_loading(self.valid_datasets[custom_name])
-
-    def load_ddp_samplers(self):
-        """
-        Load training and validation data samplers
-        """
-        if self.params["use_ddp"]:
-            self.train_sampler = DistributedSampler(self.train_dataset, num_replicas=self.params["num_gpu"], rank=self.params["ddp_rank"], shuffle=True)
-            for custom_name in self.valid_datasets.keys():
-                self.valid_samplers[custom_name] = DistributedSampler(self.valid_datasets[custom_name], num_replicas=self.params["num_gpu"], rank=self.params["ddp_rank"], shuffle=False)
-        else:
-            for custom_name in self.valid_datasets.keys():
-                self.valid_samplers[custom_name] = None
-
-    def load_dataloaders(self):
-        """
-        Load training and validation data loaders
-        """
-        # Note: sampler and batch_sampler are mutually exclusive in DataLoader;
-        # only the (element-wise) sampler is passed here
-        self.train_loader = DataLoader(self.train_dataset,
-                                       batch_size=self.batch_size["train"],
-                                       shuffle=self.train_sampler is None,
-                                       drop_last=False,
-                                       sampler=self.train_sampler,
-                                       num_workers=self.params["num_gpu"]*self.params["worker_per_gpu"],
-                                       pin_memory=True,
-                                       collate_fn=self.my_collate_function,
-                                       worker_init_fn=self.seed_worker,
-                                       generator=self.generator)
-
-        for key in self.valid_datasets.keys():
-            self.valid_loaders[key] = DataLoader(self.valid_datasets[key],
-                                                 batch_size=self.batch_size["valid"],
-                                                 sampler=self.valid_samplers[key],
-                                                 shuffle=False,
-                                                 num_workers=self.params["num_gpu"]*self.params["worker_per_gpu"],
-                                                 pin_memory=True,
-                                                 drop_last=False,
-                                                 collate_fn=self.my_collate_function,
-                                                 worker_init_fn=self.seed_worker,
-                                                 generator=self.generator)
-
-    @staticmethod
-    def seed_worker(worker_id):
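-        # Derive a distinct deterministic seed per DataLoader worker so augmentations are reproducible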
-        worker_seed = torch.initial_seed() % 2 ** 32
-        np.random.seed(worker_seed)
-        random.seed(worker_seed)
-
-    def generate_test_loader(self, custom_name, sets_list):
-        """
-        Load test dataset, data sampler and data loader
-        """
-        if custom_name in self.test_loaders.keys():
-            return
-        paths_and_sets = list()
-        for set_info in sets_list:
-            paths_and_sets.append({
-                "path": self.params["datasets"][set_info[0]],
-                "set_name": set_info[1]
-            })
-        self.test_datasets[custom_name] = self.dataset_class(self.params, "test", custom_name, paths_and_sets)
-        self.apply_specific_treatment_after_dataset_loading(self.test_datasets[custom_name])
-        if self.params["use_ddp"]:
-            self.test_samplers[custom_name] = DistributedSampler(self.test_datasets[custom_name], num_replicas=self.params["num_gpu"], rank=self.params["ddp_rank"], shuffle=False)
-        else:
-            self.test_samplers[custom_name] = None
-        self.test_loaders[custom_name] = DataLoader(self.test_datasets[custom_name],
-                                                    batch_size=self.batch_size["test"],
-                                                    sampler=self.test_samplers[custom_name],
-                                                    shuffle=False,
-                                                    num_workers=self.params["num_gpu"]*self.params["worker_per_gpu"],
-                                                    pin_memory=True,
-                                                    drop_last=False,
-                                                    collate_fn=self.my_collate_function,
-                                                    worker_init_fn=self.seed_worker,
-                                                    generator=self.generator)
-
-    def remove_test_dataset(self, custom_name):
-        del self.test_datasets[custom_name]
-        del self.test_samplers[custom_name]
-        del self.test_loaders[custom_name]
-
-    def remove_valid_dataset(self, custom_name):
-        del self.valid_datasets[custom_name]
-        del self.valid_samplers[custom_name]
-        del self.valid_loaders[custom_name]
-
-    def remove_train_dataset(self):
-        self.train_dataset = None
-        self.train_sampler = None
-        self.train_loader = None
-
-    def remove_all_datasets(self):
-        self.remove_train_dataset()
-        for name in list(self.valid_datasets.keys()):
-            self.remove_valid_dataset(name)
-        for name in list(self.test_datasets.keys()):
-            self.remove_test_dataset(name)
-
-    def get_paths_and_sets(self, dataset_names_folds):
-        paths_and_sets = list()
-        for dataset_name, fold in dataset_names_folds:
-            path = self.params["datasets"][dataset_name]
-            paths_and_sets.append({
-                "path": path,
-                "set_name": fold
-            })
-        return paths_and_sets
-
-
-class GenericDataset(Dataset):
-    """
-    Main class to handle dataset loading
-    """
-
-    def __init__(self, params, set_name, custom_name, paths_and_sets):
-        self.params = params
-        self.name = custom_name
-        self.set_name = set_name
-        self.mean = np.array(params["config"]["mean"]) if "mean" in params["config"].keys() else None
-        self.std = np.array(params["config"]["std"]) if "std" in params["config"].keys() else None
-
-        self.load_in_memory = self.params["config"]["load_in_memory"] if "load_in_memory" in self.params["config"] else True
-
-        self.samples = self.load_samples(paths_and_sets, load_in_memory=self.load_in_memory)
-
-        if self.load_in_memory:
-            self.apply_preprocessing(params["config"]["preprocessings"])
-
-        self.padding_value = params["config"]["padding_value"]
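-        # "mean" pads images with the per-channel mean of this dataset (computed on the fly if needed)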
-        if self.padding_value == "mean":
-            if self.mean is None:
-                _, _ = self.compute_std_mean()
-            self.padding_value = self.mean
-            self.params["config"]["padding_value"] = self.padding_value
-
-        self.curriculum_config = None
-        self.training_info = None
-
-    def __len__(self):
-        return len(self.samples)
-
-    @staticmethod
-    def load_image(path):
-        with Image.open(path) as pil_img:
-            img = np.array(pil_img)
-            # grayscale images: add a channel dimension
-            if len(img.shape) == 2:
-                img = np.expand_dims(img, axis=2)
-        return img
-
-    @staticmethod
-    def load_samples(paths_and_sets, load_in_memory=True):
-        """
-        Load images and labels
-        """
-        samples = list()
-        for path_and_set in paths_and_sets:
-            path = path_and_set["path"]
-            set_name = path_and_set["set_name"]
-            with open(os.path.join(path, "labels.pkl"), "rb") as f:
-                info = pickle.load(f)
-                gt = info["ground_truth"][set_name]
-                for filename in natural_sort(gt.keys()):
-                    name = os.path.join(os.path.basename(path), set_name, filename)
-                    full_path = os.path.join(path, set_name, filename)
-                    if isinstance(gt[filename], dict) and "text" in gt[filename]:
-                        label = gt[filename]["text"]
-                    else:
-                        label = gt[filename]
-                    samples.append({
-                        "name": name,
-                        "label": label,
-                        "unchanged_label": label,
-                        "path": full_path,
-                        "nb_cols": 1 if "nb_cols" not in gt[filename] else gt[filename]["nb_cols"]
-                    })
-                    if load_in_memory:
-                        samples[-1]["img"] = GenericDataset.load_image(full_path)
-                    if isinstance(gt[filename], dict):
-                        if "lines" in gt[filename].keys():
-                            samples[-1]["raw_line_seg_label"] = gt[filename]["lines"]
-                        if "paragraphs" in gt[filename].keys():
-                            samples[-1]["paragraphs_label"] = gt[filename]["paragraphs"]
-                        if "pages" in gt[filename].keys():
-                            samples[-1]["pages_label"] = gt[filename]["pages"]
-        return samples
-
-    def apply_preprocessing(self, preprocessings):
-        for i in range(len(self.samples)):
-            self.samples[i] = apply_preprocessing(self.samples[i], preprocessings)
-
-    def compute_std_mean(self):
-        """
-        Compute cumulated variance and mean of whole dataset
-        """
-        if self.mean is not None and self.std is not None:
-            return self.mean, self.std
-        if not self.load_in_memory:
-            sample = self.samples[0].copy()
-            sample["img"] = self.get_sample_img(0)
-            img = apply_preprocessing(sample, self.params["config"]["preprocessings"])["img"]
-        else:
-            img = self.get_sample_img(0)
-        _, _, c = img.shape
-        pixel_sum = np.zeros((c,))
-        nb_pixels = 0
-
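-        # First pass: accumulate per-channel pixel sums to compute the mean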
-        for i in range(len(self.samples)):
-            if not self.load_in_memory:
-                sample = self.samples[i].copy()
-                sample["img"] = self.get_sample_img(i)
-                img = apply_preprocessing(sample, self.params["config"]["preprocessings"])["img"]
-            else:
-                img = self.get_sample_img(i)
-            pixel_sum += np.sum(img, axis=(0, 1))
-            nb_pixels += np.prod(img.shape[:2])
-        mean = pixel_sum / nb_pixels
-        diff = np.zeros((c,))
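-        # Second pass: accumulate squared deviations from the mean to compute the standard deviation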
-        for i in range(len(self.samples)):
-            if not self.load_in_memory:
-                sample = self.samples[i].copy()
-                sample["img"] = self.get_sample_img(i)
-                img = apply_preprocessing(sample, self.params["config"]["preprocessings"])["img"]
-            else:
-                img = self.get_sample_img(i)
-            diff += [np.sum((img[:, :, k] - mean[k]) ** 2) for k in range(c)]
-        std = np.sqrt(diff / nb_pixels)
-
-        self.mean = mean
-        self.std = std
-        return mean, std
-
-    def apply_data_augmentation(self, img):
-        """
-        Apply data augmentation strategy on the input image
-        """
-        augs = [self.params["config"][key] if key in self.params["config"].keys() else None for key in ["augmentation", "valid_augmentation", "test_augmentation"]]
-        for aug, set_name in zip(augs, ["train", "valid", "test"]):
-            if aug and self.set_name == set_name:
-                return apply_data_augmentation(img, aug)
-        return img, list()
-
-    def get_sample_img(self, i):
-        """
-        Get image by index
-        """
-        if self.load_in_memory:
-            return self.samples[i]["img"]
-        else:
-            return GenericDataset.load_image(self.samples[i]["path"])
-
-    def denormalize(self, img):
-        """
-        Get original image, before normalization
-        """
-        return img * self.std + self.mean
-
-
-def apply_preprocessing(sample, preprocessings):
-    """
-    Apply preprocessings on each sample
-    """
-    resize_ratio = [1, 1]
-    img = sample["img"]
-    for preprocessing in preprocessings:
-
-        if preprocessing["type"] == "dpi":
-            ratio = preprocessing["target"] / preprocessing["source"]
-            temp_img = img
-            h, w, c = temp_img.shape
-            temp_img = cv2.resize(temp_img, (int(np.ceil(w * ratio)), int(np.ceil(h * ratio))))
-            if len(temp_img.shape) == 2:
-                temp_img = np.expand_dims(temp_img, axis=2)
-            img = temp_img
-
-            resize_ratio = [ratio, ratio]
-
-        if preprocessing["type"] == "to_grayscaled":
-            temp_img = img
-            h, w, c = temp_img.shape
-            if c == 3:
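-                # weighted sum using the ITU-R BT.709 luma coefficients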
-                img = np.expand_dims(
-                    0.2125 * temp_img[:, :, 0] + 0.7154 * temp_img[:, :, 1] + 0.0721 * temp_img[:, :, 2],
-                    axis=2).astype(np.uint8)
-
-        if preprocessing["type"] == "to_RGB":
-            temp_img = img
-            h, w, c = temp_img.shape
-            if c == 1:
-                img = np.concatenate([temp_img, temp_img, temp_img], axis=2)
-
-        if preprocessing["type"] == "resize":
-            keep_ratio = preprocessing["keep_ratio"]
-            max_h, max_w = preprocessing["max_height"], preprocessing["max_width"]
-            temp_img = img
-            h, w, c = temp_img.shape
-
-            ratio_h = max_h / h if max_h else 1
-            ratio_w = max_w / w if max_w else 1
-            if keep_ratio:
-                ratio_h = ratio_w = min(ratio_w, ratio_h)
-            new_h = min(max_h, int(h * ratio_h))
-            new_w = min(max_w, int(w * ratio_w))
-            temp_img = cv2.resize(temp_img, (new_w, new_h))
-            if len(temp_img.shape) == 2:
-                temp_img = np.expand_dims(temp_img, axis=2)
-
-            img = temp_img
-            resize_ratio = [ratio_h, ratio_w]
-
-        if preprocessing["type"] == "fixed_height":
-            new_h = preprocessing["height"]
-            temp_img = img
-            h, w, c = temp_img.shape
-            ratio = new_h / h
-            temp_img = cv2.resize(temp_img, (int(w*ratio), new_h))
-            if len(temp_img.shape) == 2:
-                temp_img = np.expand_dims(temp_img, axis=2)
-            img = temp_img
-            resize_ratio = [ratio, ratio]
-    if resize_ratio != [1, 1] and "raw_line_seg_label" in sample:
-        for li in range(len(sample["raw_line_seg_label"])):
-            for side, ratio in zip((["bottom", "top"], ["right", "left"]), resize_ratio):
-                for s in side:
-                    sample["raw_line_seg_label"][li][s] = sample["raw_line_seg_label"][li][s] * ratio
-
-    sample["img"] = img
-    sample["resize_ratio"] = resize_ratio
-    return sample
-
diff --git a/basic/generic_training_manager.py b/basic/generic_training_manager.py
deleted file mode 100644
index a8f41b28274678b23083123e33d2343bb6b942b9..0000000000000000000000000000000000000000
--- a/basic/generic_training_manager.py
+++ /dev/null
@@ -1,706 +0,0 @@
-import torch
-import os
-import sys
-import copy
-import json
-import torch.distributed as dist
-import torch.multiprocessing as mp
-import random
-import numpy as np
-from torch.utils.tensorboard import SummaryWriter
-from torch.nn.init import kaiming_uniform_
-from tqdm import tqdm
-from time import time
-from torch.nn.parallel import DistributedDataParallel as DDP
-from torch.cuda.amp import GradScaler
-from basic.metric_manager import MetricManager
-from basic.scheduler import DropoutScheduler
-from datetime import date
-
-
-class GenericTrainingManager:
-
-    def __init__(self, params):
-        self.type = None
-        self.is_master = False
-        self.params = params
-        self.dropout_scheduler = None
-        self.models = {}
-        self.begin_time = None
-        self.dataset = None
-        self.dataset_name = list(self.params["dataset_params"]["datasets"].values())[0]
-        self.paths = None
-        self.latest_step = 0
-        self.latest_epoch = -1
-        self.latest_batch = 0
-        self.total_batch = 0
-        self.grad_acc_step = 0
-        self.latest_train_metrics = dict()
-        self.latest_valid_metrics = dict()
-        self.curriculum_info = dict()
-        self.curriculum_info["latest_valid_metrics"] = dict()
-        self.phase = None
-        self.max_mem_usage_by_epoch = list()
-        self.losses = list()
-        self.lr_values = list()
-
-        self.scaler = None
-
-        self.optimizers = dict()
-        self.optimizers_named_params_by_group = dict()
-        self.lr_schedulers = dict()
-        self.best = None
-        self.writer = None
-        self.metric_manager = dict()
-
-        self.init_hardware_config()
-        self.init_paths()
-        self.load_dataset()
-        self.params["model_params"]["use_amp"] = self.params["training_params"]["use_amp"]
-
-    def init_paths(self):
-        """
-        Create output folders for results and checkpoints
-        """
-        output_path = os.path.join("outputs", self.params["training_params"]["output_folder"])
-        os.makedirs(output_path, exist_ok=True)
-        checkpoints_path = os.path.join(output_path, "checkpoints")
-        os.makedirs(checkpoints_path, exist_ok=True)
-        results_path = os.path.join(output_path, "results")
-        os.makedirs(results_path, exist_ok=True)
-
-        self.paths = {
-            "results": results_path,
-            "checkpoints": checkpoints_path,
-            "output_folder": output_path
-        }
-
-    def load_dataset(self):
-        """
-        Load datasets, data samplers and data loaders
-        """
-        self.params["dataset_params"]["use_ddp"] = self.params["training_params"]["use_ddp"]
-        self.params["dataset_params"]["batch_size"] = self.params["training_params"]["batch_size"]
-        if "valid_batch_size" in self.params["training_params"]:
-            self.params["dataset_params"]["valid_batch_size"] = self.params["training_params"]["valid_batch_size"]
-        if "test_batch_size" in self.params["training_params"]:
-            self.params["dataset_params"]["test_batch_size"] = self.params["training_params"]["test_batch_size"]
-        self.params["dataset_params"]["num_gpu"] = self.params["training_params"]["nb_gpu"]
-        self.params["dataset_params"]["worker_per_gpu"] = 4 if "worker_per_gpu" not in self.params["dataset_params"] else self.params["dataset_params"]["worker_per_gpu"]
-        self.dataset = self.params["dataset_params"]["dataset_manager"](self.params["dataset_params"])
-        self.dataset.load_datasets()
-        self.dataset.load_ddp_samplers()
-        self.dataset.load_dataloaders()
-
-    def init_hardware_config(self):
-        # Debug mode
-        if self.params["training_params"]["force_cpu"]:
-            self.params["training_params"]["use_ddp"] = False
-            self.params["training_params"]["use_amp"] = False
-        # Manage Distributed Data Parallel & GPU usage
-        self.manual_seed = self.params["training_params"].get("manual_seed", 1111)
-        self.ddp_config = {
-            "master": self.params["training_params"]["use_ddp"] and self.params["training_params"]["ddp_rank"] == 0,
-            "address": "localhost" if "ddp_addr" not in self.params["training_params"].keys() else self.params["training_params"]["ddp_addr"],
-            "port": "11111" if "ddp_port" not in self.params["training_params"].keys() else self.params["training_params"]["ddp_port"],
-            "backend": "nccl" if "ddp_backend" not in self.params["training_params"].keys() else self.params["training_params"]["ddp_backend"],
-            "rank": self.params["training_params"]["ddp_rank"],
-        }
-        self.is_master = self.ddp_config["master"] or not self.params["training_params"]["use_ddp"]
-        if self.params["training_params"]["force_cpu"]:
-            self.device = "cpu"
-        else:
-            if self.params["training_params"]["use_ddp"]:
-                self.device = torch.device(self.ddp_config["rank"])
-                self.params["dataset_params"]["ddp_rank"] = self.ddp_config["rank"]
-                self.launch_ddp()
-            else:
-                self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-        self.params["model_params"]["device"] = self.device.type
-        # Print GPU info
-        # global
-        if (self.params["training_params"]["use_ddp"] and self.ddp_config["master"]) or not self.params["training_params"]["use_ddp"]:
-            print("##################")
-            print("Available GPUS: {}".format(self.params["training_params"]["nb_gpu"]))
-            for i in range(self.params["training_params"]["nb_gpu"]):
-                print("Rank {}: {} {}".format(i, torch.cuda.get_device_name(i), torch.cuda.get_device_properties(i)))
-            print("##################")
-        # local
-        print("Local GPU:")
-        if self.device != "cpu":
-            print("Rank {}: {} {}".format(self.params["training_params"]["ddp_rank"], torch.cuda.get_device_name(), torch.cuda.get_device_properties(self.device)))
-        else:
-            print("WORKING ON CPU !\n")
-        print("##################")
-
-    def load_model(self, reset_optimizer=False, strict=True):
-        """
-        Load model weights from scratch or from checkpoints
-        """
-        # Instantiate Model
-        for model_name in self.params["model_params"]["models"].keys():
-            self.models[model_name] = self.params["model_params"]["models"][model_name](self.params["model_params"])
-            self.models[model_name].to(self.device)  # To GPU or CPU
-            # make the model compatible with Distributed Data Parallel if used
-            if self.params["training_params"]["use_ddp"]:
-                self.models[model_name] = DDP(self.models[model_name], [self.ddp_config["rank"]])
-
-        # Handle curriculum dropout
-        if "dropout_scheduler" in self.params["model_params"]:
-            func = self.params["model_params"]["dropout_scheduler"]["function"]
-            T = self.params["model_params"]["dropout_scheduler"]["T"]
-            self.dropout_scheduler = DropoutScheduler(self.models, func, T)
-
-        self.scaler = GradScaler(enabled=self.params["training_params"]["use_amp"])
-
-        # Check if checkpoint exists
-        checkpoint = self.get_checkpoint()
-        if checkpoint is not None:
-            self.load_existing_model(checkpoint, strict=strict)
-        else:
-            self.init_new_model()
-
-        self.load_optimizers(checkpoint, reset_optimizer=reset_optimizer)
-
-        if self.is_master:
-            print("LOADED EPOCH: {}\n".format(self.latest_epoch), flush=True)
-
-    def get_checkpoint(self):
-        """
-        Seek if checkpoint exist, return None otherwise
-        """
-        if self.params["training_params"]["load_epoch"] in ("best", "last"):
-            for filename in os.listdir(self.paths["checkpoints"]):
-                if self.params["training_params"]["load_epoch"] in filename:
-                    return torch.load(os.path.join(self.paths["checkpoints"], filename))
-        return None
-
-    def load_existing_model(self, checkpoint, strict=True):
-        """
-        Load information and weights from previous training
-        """
-        self.load_save_info(checkpoint)
-        self.latest_epoch = checkpoint["epoch"]
-        if "step" in checkpoint:
-            self.latest_step = checkpoint["step"]
-        self.best = checkpoint["best"]
-        if "scaler_state_dict" in checkpoint:
-            self.scaler.load_state_dict(checkpoint["scaler_state_dict"])
-        # Load model weights from past training
-        for model_name in self.models.keys():
-            self.models[model_name].load_state_dict(checkpoint["{}_state_dict".format(model_name)], strict=strict)
-
-    def init_new_model(self):
-        """
-        Initialize model
-        """
-        # Specific weights initialization if exists
-        for model_name in self.models.keys():
-            try:
-                self.models[model_name].init_weights()
-            except AttributeError:
-                # the model does not define a custom init_weights method; keep default initialization
-                pass
-
-        # Handle transfer learning instructions
-        if self.params["model_params"]["transfer_learning"]:
-            # Iterates over models
-            for model_name in self.params["model_params"]["transfer_learning"].keys():
-                state_dict_name, path, learnable, strict = self.params["model_params"]["transfer_learning"][model_name]
-                # Loading pretrained weights file
-                checkpoint = torch.load(path)
-                try:
-                    # Load pretrained weights for model
-                    self.models[model_name].load_state_dict(checkpoint["{}_state_dict".format(state_dict_name)], strict=strict)
-                    print("transfered weights for {}".format(state_dict_name), flush=True)
-                except RuntimeError as e:
-                    print(e, flush=True)
-                    # if error, try to load each parts of the model (useful if only few layers are different)
-                    for key in checkpoint["{}_state_dict".format(state_dict_name)].keys():
-                        try:
-                            # for pre-training of decision layer
-                            if "end_conv" in key and "transfered_charset" in self.params["model_params"]:
-                                self.adapt_decision_layer_to_old_charset(model_name, key, checkpoint, state_dict_name)
-                            else:
-                                self.models[model_name].load_state_dict(
-                                    {key: checkpoint["{}_state_dict".format(state_dict_name)][key]}, strict=False)
-                        except RuntimeError as e:
-                            # exception raised when adding the linebreak token from pre-training
-                            print(e, flush=True)
-                # Freeze the transferred parameters when they must not be fine-tuned
-                if not learnable:
-                    self.set_model_learnable(self.models[model_name], False)
-
-    def adapt_decision_layer_to_old_charset(self, model_name, key, checkpoint, state_dict_name):
-        """
-        Transfer learning of the decision learning in case of close charsets between pre-training and training
-        """
-        pretrained_chars = list()
-        weights = checkpoint["{}_state_dict".format(state_dict_name)][key]
-        new_size = list(weights.size())
-        new_size[0] = len(self.dataset.charset) + self.params["model_params"]["additional_tokens"]
-        new_weights = torch.zeros(new_size, device=weights.device, dtype=weights.dtype)
-        old_charset = checkpoint["charset"] if "charset" in checkpoint else self.params["model_params"]["old_charset"]
-        if not "bias" in key:
-            kaiming_uniform_(new_weights, nonlinearity="relu")
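-        # copy pretrained weights for characters shared with the old charset; other rows keep the new initialization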
-        for i, c in enumerate(self.dataset.charset):
-            if c in old_charset:
-                new_weights[i] = weights[old_charset.index(c)]
-                pretrained_chars.append(c)
-        if "transfered_charset_last_is_ctc_blank" in self.params["model_params"] and self.params["model_params"]["transfered_charset_last_is_ctc_blank"]:
-            new_weights[-1] = weights[-1]
-            pretrained_chars.append("<blank>")
-        checkpoint["{}_state_dict".format(state_dict_name)][key] = new_weights
-        self.models[model_name].load_state_dict({key: checkpoint["{}_state_dict".format(state_dict_name)][key]}, strict=False)
-        print("Pretrained chars for {} ({}): {}".format(key, len(pretrained_chars), pretrained_chars))
-
-    def load_optimizers(self, checkpoint, reset_optimizer=False):
-        """
-        Load the optimizer of each model
-        """
-        for model_name in self.models.keys():
-            new_params = dict()
-            if checkpoint and "optimizer_named_params_{}".format(model_name) in checkpoint:
-                self.optimizers_named_params_by_group[model_name] = checkpoint["optimizer_named_params_{}".format(model_name)]
-                # for progressively growing models
-                for name, param in self.models[model_name].named_parameters():
-                    existing = False
-                    for gr in self.optimizers_named_params_by_group[model_name]:
-                        if name in gr:
-                            gr[name] = param
-                            existing = True
-                            break
-                    if not existing:
-                        new_params.update({name: param})
-            else:
-                self.optimizers_named_params_by_group[model_name] = [dict(), ]
-                self.optimizers_named_params_by_group[model_name][0].update(self.models[model_name].named_parameters())
-
-            # Instantiate optimizer
-            self.reset_optimizer(model_name)
-
-            # Handle learning rate schedulers
-            if "lr_schedulers" in self.params["training_params"] and self.params["training_params"]["lr_schedulers"]:
-                key = "all" if "all" in self.params["training_params"]["lr_schedulers"] else model_name
-                if key in self.params["training_params"]["lr_schedulers"]:
-                    self.lr_schedulers[model_name] = self.params["training_params"]["lr_schedulers"][key]["class"]\
-                        (self.optimizers[model_name], **self.params["training_params"]["lr_schedulers"][key]["args"])
-
-            # Load optimizer state from past training
-            if checkpoint and not reset_optimizer:
-                self.optimizers[model_name].load_state_dict(checkpoint["optimizer_{}_state_dict".format(model_name)])
-                # Load optimizer scheduler config from past training if used
-                if "lr_schedulers" in self.params["training_params"] and self.params["training_params"]["lr_schedulers"] \
-                        and "lr_scheduler_{}_state_dict".format(model_name) in checkpoint.keys():
-                    self.lr_schedulers[model_name].load_state_dict(checkpoint["lr_scheduler_{}_state_dict".format(model_name)])
-
-            # for progressively growing models, keeping learning rate
-            if checkpoint and new_params:
-                self.optimizers_named_params_by_group[model_name].append(new_params)
-                self.optimizers[model_name].add_param_group({"params": list(new_params.values())})
-
-    @staticmethod
-    def set_model_learnable(model, learnable=True):
-        for p in list(model.parameters()):
-            p.requires_grad = learnable
-
-    def save_model(self, epoch, name, keep_weights=False):
-        """
-        Save model weights and training info for curriculum learning or learning rate for instance
-        """
-        if not self.is_master:
-            return
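-        # collect existing checkpoints with the same name; they are deleted after saving unless keep_weights is set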
-        to_del = []
-        for filename in os.listdir(self.paths["checkpoints"]):
-            if name in filename:
-                to_del.append(os.path.join(self.paths["checkpoints"], filename))
-        path = os.path.join(self.paths["checkpoints"], "{}_{}.pt".format(name, epoch))
-        content = {
-            'optimizers_named_params': self.optimizers_named_params_by_group,
-            'epoch': epoch,
-            'step': self.latest_step,
-            "scaler_state_dict": self.scaler.state_dict(),
-            'best': self.best,
-            "charset": self.dataset.charset
-        }
-        for model_name in self.optimizers:
-            content['optimizer_{}_state_dict'.format(model_name)] = self.optimizers[model_name].state_dict()
-        for model_name in self.lr_schedulers:
-            content["lr_scheduler_{}_state_dict".format(model_name)] = self.lr_schedulers[model_name].state_dict()
-        content = self.add_save_info(content)
-        for model_name in self.models.keys():
-            content["{}_state_dict".format(model_name)] = self.models[model_name].state_dict()
-        torch.save(content, path)
-        if not keep_weights:
-            for path_to_del in to_del:
-                if path_to_del != path:
-                    os.remove(path_to_del)
-
-    def reset_optimizers(self):
-        """
-        Reset learning rate of all optimizers
-        """
-        for model_name in self.models.keys():
-            self.reset_optimizer(model_name)
-
-    def reset_optimizer(self, model_name):
-        """
-        Reset optimizer learning rate for given model
-        """
-        params = list(self.optimizers_named_params_by_group[model_name][0].values())
-        key = "all" if "all" in self.params["training_params"]["optimizers"] else model_name
-        self.optimizers[model_name] = self.params["training_params"]["optimizers"][key]["class"](params, **self.params["training_params"]["optimizers"][key]["args"])
-        for i in range(1, len(self.optimizers_named_params_by_group[model_name])):
-            self.optimizers[model_name].add_param_group({"params": list(self.optimizers_named_params_by_group[model_name][i].values())})
-
-    def save_params(self):
-        """
-        Output text file containing a summary of all hyperparameters chosen for the training
-        """
-        def compute_nb_params(module):
-            return sum([np.prod(p.size()) for p in list(module.parameters())])
-
-        def class_to_str_dict(my_dict):
-            for key in my_dict.keys():
-                if callable(my_dict[key]):
-                    my_dict[key] = my_dict[key].__name__
-                elif isinstance(my_dict[key], np.ndarray):
-                    my_dict[key] = my_dict[key].tolist()
-                elif isinstance(my_dict[key], dict):
-                    my_dict[key] = class_to_str_dict(my_dict[key])
-            return my_dict
-
-        path = os.path.join(self.paths["results"], "params")
-        if os.path.isfile(path):
-            return
-        params = copy.deepcopy(self.params)
-        params = class_to_str_dict(params)
-        params["date"] = date.today().strftime("%d/%m/%Y")
-        total_params = 0
-        for model_name in self.models.keys():
-            current_params = compute_nb_params(self.models[model_name])
-            params["model_params"]["models"][model_name] = [params["model_params"]["models"][model_name], "{:,}".format(current_params)]
-            total_params += current_params
-        params["model_params"]["total_params"] = "{:,}".format(total_params)
-
-        params["hardware"] = dict()
-        if self.device != "cpu":
-            for i in range(self.params["training_params"]["nb_gpu"]):
-                params["hardware"][str(i)] = "{} {}".format(torch.cuda.get_device_name(i), torch.cuda.get_device_properties(i))
-        else:
-            params["hardware"]["0"] = "CPU"
-        params["software"] = {
-            "python_version": sys.version,
-            "pytorch_version": torch.__version__,
-            "cuda_version": torch.version.cuda,
-            "cudnn_version": torch.backends.cudnn.version(),
-        }
-        with open(path, 'w') as f:
-            json.dump(params, f, indent=4)
-
-    def backward_loss(self, loss, retain_graph=False):
-        self.scaler.scale(loss).backward(retain_graph=retain_graph)
-
-    def step_optimizers(self, increment_step=True, names=None):
-        for model_name in self.optimizers:
-            if names and model_name not in names:
-                continue
-            if "gradient_clipping" in self.params["training_params"] and model_name in self.params["training_params"]["gradient_clipping"]["models"]:
-                self.scaler.unscale_(self.optimizers[model_name])
-                torch.nn.utils.clip_grad_norm_(self.models[model_name].parameters(), self.params["training_params"]["gradient_clipping"]["max"])
-            self.scaler.step(self.optimizers[model_name])
-        self.scaler.update()
-        self.latest_step += 1
-
-    def zero_optimizers(self, set_to_none=True):
-        for model_name in self.optimizers:
-            self.zero_optimizer(model_name, set_to_none)
-
-    def zero_optimizer(self, model_name, set_to_none=True):
-        self.optimizers[model_name].zero_grad(set_to_none=set_to_none)
-
-    def train(self):
-        """
-        Main training loop
-        """
-        # init tensorboard file and output param summary file
-        if self.is_master:
-            self.writer = SummaryWriter(self.paths["results"])
-            self.save_params()
-        # init variables
-        self.begin_time = time()
-        focus_metric_name = self.params["training_params"]["focus_metric"]
-        nb_epochs = self.params["training_params"]["max_nb_epochs"]
-        interval_save_weights = self.params["training_params"]["interval_save_weights"]
-        metric_names = self.params["training_params"]["train_metrics"]
-
-        display_values = None
-        # init curriculum learning
-        if "curriculum_learning" in self.params["training_params"].keys() and self.params["training_params"]["curriculum_learning"]:
-            self.init_curriculum()
-        # perform epochs
-        for num_epoch in range(self.latest_epoch+1, nb_epochs):
-            self.dataset.train_dataset.training_info = {
-                "epoch": self.latest_epoch,
-                "step": self.latest_step
-            }
-            self.phase = "train"
-            # Check maximum training time stop condition
-            if self.params["training_params"]["max_training_time"] and time() - self.begin_time > self.params["training_params"]["max_training_time"]:
-                break
-            # set models trainable
-            for model_name in self.models.keys():
-                self.models[model_name].train()
-            self.latest_epoch = num_epoch
-            if self.dataset.train_dataset.curriculum_config:
-                self.dataset.train_dataset.curriculum_config["epoch"] = self.latest_epoch
-            # init epoch metrics values
-            self.metric_manager["train"] = MetricManager(metric_names=metric_names, dataset_name=self.dataset_name)
-
-            with tqdm(total=len(self.dataset.train_loader.dataset)) as pbar:
-                pbar.set_description("EPOCH {}/{}".format(num_epoch, nb_epochs))
-                # iterates over mini-batch data
-                for ind_batch, batch_data in enumerate(self.dataset.train_loader):
-                    self.latest_batch = ind_batch + 1
-                    self.total_batch += 1
-                    # train on batch data and compute metrics
-                    batch_values = self.train_batch(batch_data, metric_names)
-                    batch_metrics = self.metric_manager["train"].compute_metrics(batch_values, metric_names)
-                    batch_metrics["names"] = batch_data["names"]
-                    batch_metrics["ids"] = batch_data["ids"]
-                    # Merge metrics if Distributed Data Parallel is used
-                    if self.params["training_params"]["use_ddp"]:
-                        batch_metrics = self.merge_ddp_metrics(batch_metrics)
-                    # Update learning rate via scheduler if one is used
-                    if self.params["training_params"]["lr_schedulers"]:
-                        for model_name in self.models:
-                            key = "all" if "all" in self.params["training_params"]["lr_schedulers"] else model_name
-                            if model_name in self.lr_schedulers and ind_batch % self.params["training_params"]["lr_schedulers"][key]["step_interval"] == 0:
-                                self.lr_schedulers[model_name].step(len(batch_metrics["names"]))
-                                if "lr" in metric_names:
-                                    self.writer.add_scalar("lr_{}".format(model_name), self.lr_schedulers[model_name].lr, self.lr_schedulers[model_name].step_num)
-                    # Update dropout scheduler if used
-                    if self.dropout_scheduler:
-                        self.dropout_scheduler.step(len(batch_metrics["names"]))
-                        self.dropout_scheduler.update_dropout_rate()
-
-                    # Add batch metrics values to epoch metrics values
-                    self.metric_manager["train"].update_metrics(batch_metrics)
-                    display_values = self.metric_manager["train"].get_display_values()
-                    pbar.set_postfix(values=str(display_values))
-                    pbar.update(len(batch_data["names"]))
-
-            # log metrics in tensorboard file
-            if self.is_master:
-                for key in display_values.keys():
-                    self.writer.add_scalar('{}_{}'.format(self.params["dataset_params"]["train"]["name"], key), display_values[key], num_epoch)
-            self.latest_train_metrics = display_values
-
-            # evaluate and compute metrics for valid sets
-            if self.params["training_params"]["eval_on_valid"] and num_epoch % self.params["training_params"]["eval_on_valid_interval"] == 0:
-                for valid_set_name in self.dataset.valid_loaders.keys():
-                    # evaluate set and compute metrics
-                    eval_values = self.evaluate(valid_set_name)
-                    self.latest_valid_metrics = eval_values
-                    # log valid metrics in tensorboard file
-                    if self.is_master:
-                        for key in eval_values.keys():
-                            self.writer.add_scalar('{}_{}'.format(valid_set_name, key), eval_values[key], num_epoch)
-                        if valid_set_name == self.params["training_params"]["set_name_focus_metric"] and (self.best is None or \
-                                (eval_values[focus_metric_name] <= self.best and self.params["training_params"]["expected_metric_value"] == "low") or\
-                                (eval_values[focus_metric_name] >= self.best and self.params["training_params"]["expected_metric_value"] == "high")):
-                            self.save_model(epoch=num_epoch, name="best")
-                            self.best = eval_values[focus_metric_name]
-
-            # Handle curriculum learning update
-            if self.dataset.train_dataset.curriculum_config:
-                self.check_and_update_curriculum()
-
-            if "curriculum_model" in self.params["model_params"] and self.params["model_params"]["curriculum_model"]:
-                self.update_curriculum_model()
-
-            # save model weights
-            if self.is_master:
-                self.save_model(epoch=num_epoch, name="last")
-                if interval_save_weights and num_epoch % interval_save_weights == 0:
-                    self.save_model(epoch=num_epoch, name="weigths", keep_weights=True)
-                self.writer.flush()
-
-    def evaluate(self, set_name, **kwargs):
-        """
-        Main loop for validation
-        """
-        self.phase = "eval"
-        loader = self.dataset.valid_loaders[set_name]
-        # Set models in eval mode
-        for model_name in self.models.keys():
-            self.models[model_name].eval()
-        metric_names = self.params["training_params"]["eval_metrics"]
-        display_values = None
-
-        # initialize epoch metrics
-        self.metric_manager[set_name] = MetricManager(metric_names, dataset_name=self.dataset_name)
-        with tqdm(total=len(loader.dataset)) as pbar:
-            pbar.set_description("Evaluation E{}".format(self.latest_epoch))
-            with torch.no_grad():
-                # iterate over batch data
-                for ind_batch, batch_data in enumerate(loader):
-                    self.latest_batch = ind_batch + 1
-                    # eval batch data and compute metrics
-                    batch_values = self.evaluate_batch(batch_data, metric_names)
-                    batch_metrics = self.metric_manager[set_name].compute_metrics(batch_values, metric_names)
-                    batch_metrics["names"] = batch_data["names"]
-                    batch_metrics["ids"] = batch_data["ids"]
-                    # merge metrics values if Distributed Data Parallel is used
-                    if self.params["training_params"]["use_ddp"]:
-                        batch_metrics = self.merge_ddp_metrics(batch_metrics)
-
-                    # add batch metrics to epoch metrics
-                    self.metric_manager[set_name].update_metrics(batch_metrics)
-                    display_values = self.metric_manager[set_name].get_display_values()
-
-                    pbar.set_postfix(values=str(display_values))
-                    pbar.update(len(batch_data["names"]))
-        if "cer_by_nb_cols" in metric_names:
-            self.log_cer_by_nb_cols(set_name)
-        return display_values
-
-    def predict(self, custom_name, sets_list, metric_names, output=False):
-        """
-        Main loop for evaluation
-        """
-        self.phase = "predict"
-        metric_names = metric_names.copy()
-        self.dataset.generate_test_loader(custom_name, sets_list)
-        loader = self.dataset.test_loaders[custom_name]
-        # Set models in eval mode
-        for model_name in self.models.keys():
-            self.models[model_name].eval()
-
-        # initialize epoch metrics
-        self.metric_manager[custom_name] = MetricManager(metric_names, self.dataset_name)
-
-        with tqdm(total=len(loader.dataset)) as pbar:
-            pbar.set_description("Prediction")
-            with torch.no_grad():
-                for ind_batch, batch_data in enumerate(loader):
-                    # iterates over batch data
-                    self.latest_batch = ind_batch + 1
-                    # eval batch data and compute metrics
-                    batch_values = self.evaluate_batch(batch_data, metric_names)
-                    batch_metrics = self.metric_manager[custom_name].compute_metrics(batch_values, metric_names)
-                    batch_metrics["names"] = batch_data["names"]
-                    batch_metrics["ids"] = batch_data["ids"]
-                    # merge batch metrics if Distributed Data Parallel is used
-                    if self.params["training_params"]["use_ddp"]:
-                        batch_metrics = self.merge_ddp_metrics(batch_metrics)
-
-                    # add batch metrics to epoch metrics
-                    self.metric_manager[custom_name].update_metrics(batch_metrics)
-                    display_values = self.metric_manager[custom_name].get_display_values()
-
-                    pbar.set_postfix(values=str(display_values))
-                    pbar.update(len(batch_data["names"]))
-
-        self.dataset.remove_test_dataset(custom_name)
-        # output metrics values if requested
-        if output:
-            if "pred" in metric_names:
-                self.output_pred(custom_name)
-            metrics = self.metric_manager[custom_name].get_display_values(output=True)
-            path = os.path.join(self.paths["results"], "predict_{}_{}.txt".format(custom_name, self.latest_epoch))
-            with open(path, "w") as f:
-                for metric_name in metrics.keys():
-                    f.write("{}: {}\n".format(metric_name, metrics[metric_name]))
-
-    def output_pred(self, name):
-        path = os.path.join(self.paths["results"], "pred_{}_{}.txt".format(name, self.latest_epoch))
-        pred = "\n".join(self.metric_manager[name].get("pred"))
-        with open(path, "w") as f:
-            f.write(pred)
-
-    def launch_ddp(self):
-        """
-        Initialize Distributed Data Parallel system
-        """
-        mp.set_start_method('fork', force=True)
-        os.environ['MASTER_ADDR'] = self.ddp_config["address"]
-        os.environ['MASTER_PORT'] = str(self.ddp_config["port"])
-        dist.init_process_group(self.ddp_config["backend"], rank=self.ddp_config["rank"], world_size=self.params["training_params"]["nb_gpu"])
-        torch.cuda.set_device(self.ddp_config["rank"])
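-        # seed every RNG source identically on each process for reproducible runs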
-        random.seed(self.manual_seed)
-        np.random.seed(self.manual_seed)
-        torch.manual_seed(self.manual_seed)
-        torch.cuda.manual_seed(self.manual_seed)
-
-    def merge_ddp_metrics(self, metrics):
-        """
-        Merge metrics when Distributed Data Parallel is used
-        """
-        for metric_name in metrics.keys():
-            if metric_name in ["edit_words", "nb_words", "edit_chars", "nb_chars", "edit_chars_force_len",
-                               "edit_chars_curr", "nb_chars_curr", "ids"]:
-                metrics[metric_name] = self.cat_ddp_metric(metrics[metric_name])
-            elif metric_name in ["nb_samples", "loss", "loss_ce", "loss_ctc", "loss_ce_end"]:
-                metrics[metric_name] = self.sum_ddp_metric(metrics[metric_name], average=False)
-        return metrics
-
-    def sum_ddp_metric(self, metric, average=False):
-        """
-        Sum a metric across Distributed Data Parallel processes
-        """
-        total = torch.tensor(metric[0]).to(self.device)
-        dist.all_reduce(total, op=dist.ReduceOp.SUM)
-        if average:
-            # true_divide returns a new tensor; keep the result instead of discarding it
-            total = total / dist.get_world_size()
-        return [total.item(), ]
-
-    def cat_ddp_metric(self, metric):
-        """
-        Concatenate metrics for Distributed Data Parallel
-        """
-        tensor = torch.tensor(metric).unsqueeze(0).to(self.device)
-        res = [torch.zeros(tensor.size()).long().to(self.device) for _ in range(dist.get_world_size())]
-        dist.all_gather(res, tensor)
-        return list(torch.cat(res, dim=0).flatten().cpu().numpy())
-
-    @staticmethod
-    def cleanup():
-        dist.destroy_process_group()
-
-    def train_batch(self, batch_data, metric_names):
-        raise NotImplementedError
-
-    def evaluate_batch(self, batch_data, metric_names):
-        raise NotImplementedError
-
-    def init_curriculum(self):
-        raise NotImplementedError
-
-    def update_curriculum(self):
-        raise NotImplementedError
-
-    def add_checkpoint_info(self, load_mode="last", **kwargs):
-        for filename in os.listdir(self.paths["checkpoints"]):
-            if load_mode in filename:
-                checkpoint_path = os.path.join(self.paths["checkpoints"], filename)
-                checkpoint = torch.load(checkpoint_path)
-                for key in kwargs.keys():
-                    checkpoint[key] = kwargs[key]
-                torch.save(checkpoint, checkpoint_path)
-                return
-        self.save_model(self.latest_epoch, "last")
-
-    def load_save_info(self, info_dict):
-        """
-        Load curriculum info from saved model info
-        """
-        if "curriculum_config" in info_dict.keys():
-            self.dataset.train_dataset.curriculum_config = info_dict["curriculum_config"]
-
-    def add_save_info(self, info_dict):
-        """
-        Add curriculum info to model info to be saved
-        """
-        info_dict["curriculum_config"] = self.dataset.train_dataset.curriculum_config
-        return info_dict
\ No newline at end of file
diff --git a/basic/metric_manager.py b/basic/metric_manager.py
deleted file mode 100644
index 6c8576863a2d26a85c2d5dbc4321323dedf870a0..0000000000000000000000000000000000000000
--- a/basic/metric_manager.py
+++ /dev/null
@@ -1,538 +0,0 @@
-
-from Datasets.dataset_formatters.rimes_formatter import SEM_MATCHING_TOKENS as RIMES_MATCHING_TOKENS
-from Datasets.dataset_formatters.read2016_formatter import SEM_MATCHING_TOKENS as READ_MATCHING_TOKENS
-from Datasets.dataset_formatters.simara_formatter import SEM_MATCHING_TOKENS as SIMARA_MATCHING_TOKENS
-import re
-import networkx as nx
-import editdistance
-import numpy as np
-from dan.post_processing import PostProcessingModuleREAD, PostProcessingModuleRIMES, PostProcessingModuleSIMARA
-
-
-class MetricManager:
-
-    def __init__(self, metric_names, dataset_name):
-        self.dataset_name = dataset_name
-        if "READ" in dataset_name and "page" in dataset_name:
-            self.post_processing_module = PostProcessingModuleREAD
-            self.matching_tokens = READ_MATCHING_TOKENS
-            self.edit_and_num_edge_nodes = edit_and_num_items_for_ged_from_str_read
-        elif "RIMES" in dataset_name and "page" in dataset_name:
-            self.post_processing_module = PostProcessingModuleRIMES
-            self.matching_tokens = RIMES_MATCHING_TOKENS
-            self.edit_and_num_edge_nodes = edit_and_num_items_for_ged_from_str_rimes
-        elif "simara" in dataset_name and "page" in dataset_name:
-            self.post_processing_module = PostProcessingModuleSIMARA
-            self.matching_tokens = SIMARA_MATCHING_TOKENS
-            self.edit_and_num_edge_nodes = edit_and_num_items_for_ged_from_str_simara
-        else:
-            self.matching_tokens = dict()
-
-        self.layout_tokens = "".join(list(self.matching_tokens.keys()) + list(self.matching_tokens.values()))
-        if len(self.layout_tokens) == 0:
-            self.layout_tokens = None
-        self.metric_names = metric_names
-        self.epoch_metrics = None
-
-        self.linked_metrics = {
-            "cer": ["edit_chars", "nb_chars"],
-            "wer": ["edit_words", "nb_words"],
-            "loer": ["edit_graph", "nb_nodes_and_edges", "nb_pp_op_layout", "nb_gt_layout_token"],
-            "precision": ["precision", "weights"],
-            "map_cer_per_class": ["map_cer", ],
-            "layout_precision_per_class_per_threshold": ["map_cer", ],
-        }
-
-        self.init_metrics()
-
-    def init_metrics(self):
-        """
-        Initialization of the metrics specified in metrics_name
-        """
-        self.epoch_metrics = {
-            "nb_samples": list(),
-            "names": list(),
-            "ids": list(),
-        }
-
-        for metric_name in self.metric_names:
-            if metric_name in self.linked_metrics:
-                for linked_metric_name in self.linked_metrics[metric_name]:
-                    if linked_metric_name not in self.epoch_metrics.keys():
-                        self.epoch_metrics[linked_metric_name] = list()
-            else:
-                self.epoch_metrics[metric_name] = list()
-
-    def update_metrics(self, batch_metrics):
-        """
-        Add batch metrics to the metrics
-        """
-        for key in batch_metrics.keys():
-            if key in self.epoch_metrics:
-                self.epoch_metrics[key] += batch_metrics[key]
-
-    def get_display_values(self, output=False):
-        """
-        Format metric values for shell display purposes
-        """
-        metric_names = self.metric_names.copy()
-        if output:
-            metric_names.extend(["nb_samples"])
-        display_values = dict()
-        for metric_name in metric_names:
-            value = None
-            if output:
-                if metric_name in ["nb_samples", "weights"]:
-                    value = np.sum(self.epoch_metrics[metric_name])
-                elif metric_name in ["time", ]:
-                    total_time = np.sum(self.epoch_metrics[metric_name])
-                    sample_time = total_time / np.sum(self.epoch_metrics["nb_samples"])
-                    display_values["sample_time"] = round(sample_time, 4)
-                    value = total_time
-                elif metric_name == "loer":
-                    display_values["pper"] = round(np.sum(self.epoch_metrics["nb_pp_op_layout"]) / np.sum(self.epoch_metrics["nb_gt_layout_token"]), 4)
-                elif metric_name == "map_cer_per_class":
-                    value = compute_global_mAP_per_class(self.epoch_metrics["map_cer"])
-                    for key in value.keys():
-                        display_values["map_cer_" + key] = round(value[key], 4)
-                    continue
-                elif metric_name == "layout_precision_per_class_per_threshold":
-                    value = compute_global_precision_per_class_per_threshold(self.epoch_metrics["map_cer"])
-                    for key_class in value.keys():
-                        for threshold in value[key_class].keys():
-                            display_values["map_cer_{}_{}".format(key_class, threshold)] = round(
-                                value[key_class][threshold], 4)
-                    continue
-            if metric_name == "cer":
-                value = np.sum(self.epoch_metrics["edit_chars"]) / np.sum(self.epoch_metrics["nb_chars"])
-                if output:
-                    display_values["nb_chars"] = np.sum(self.epoch_metrics["nb_chars"])
-            elif metric_name == "wer":
-                value = np.sum(self.epoch_metrics["edit_words"]) / np.sum(self.epoch_metrics["nb_words"])
-                if output:
-                    display_values["nb_words"] = np.sum(self.epoch_metrics["nb_words"])
-            elif metric_name in ["loss", "loss_ctc", "loss_ce", "syn_max_lines"]:
-                value = np.average(self.epoch_metrics[metric_name], weights=np.array(self.epoch_metrics["nb_samples"]))
-            elif metric_name == "map_cer":
-                value = compute_global_mAP(self.epoch_metrics[metric_name])
-            elif metric_name == "loer":
-                value = np.sum(self.epoch_metrics["edit_graph"]) / np.sum(self.epoch_metrics["nb_nodes_and_edges"])
-            elif value is None:
-                continue
-
-            display_values[metric_name] = round(value, 4)
-        return display_values
-
-    def compute_metrics(self, values, metric_names):
-        metrics = {
-            "nb_samples": [values["nb_samples"], ],
-        }
-        for v in ["weights", "time"]:
-            if v in values:
-                metrics[v] = [values[v]]
-        for metric_name in metric_names:
-            if metric_name == "cer":
-                metrics["edit_chars"] = [edit_cer_from_string(u, v, self.layout_tokens) for u, v in zip(values["str_y"], values["str_x"])]
-                metrics["nb_chars"] = [nb_chars_cer_from_string(gt, self.layout_tokens) for gt in values["str_y"]]
-            elif metric_name == "wer":
-                split_gt = [format_string_for_wer(gt, self.layout_tokens) for gt in values["str_y"]]
-                split_pred = [format_string_for_wer(pred, self.layout_tokens) for pred in values["str_x"]]
-                metrics["edit_words"] = [edit_wer_from_formatted_split_text(gt, pred) for (gt, pred) in zip(split_gt, split_pred)]
-                metrics["nb_words"] = [len(gt) for gt in split_gt]
-            elif metric_name in ["loss_ctc", "loss_ce", "loss", "syn_max_lines", ]:
-                metrics[metric_name] = [values[metric_name], ]
-            elif metric_name == "map_cer":
-                pp_pred = list()
-                pp_score = list()
-                for pred, score in zip(values["str_x"], values["confidence_score"]):
-                    pred_score = self.post_processing_module().post_process(pred, score)
-                    pp_pred.append(pred_score[0])
-                    pp_score.append(pred_score[1])
-                metrics[metric_name] = [compute_layout_mAP_per_class(y, x, conf, self.matching_tokens) for x, conf, y in zip(pp_pred, pp_score, values["str_y"])]
-            elif metric_name == "loer":
-                pp_pred = list()
-                metrics["nb_pp_op_layout"] = list()
-                for pred in values["str_x"]:
-                    pp_module = self.post_processing_module()
-                    pp_pred.append(pp_module.post_process(pred))
-                    metrics["nb_pp_op_layout"].append(pp_module.num_op)
-                metrics["nb_gt_layout_token"] = [len(keep_only_tokens(str_x, self.layout_tokens)) for str_x in values["str_x"]]
-                edit_and_num_items = [self.edit_and_num_edge_nodes(y, x) for x, y in zip(pp_pred, values["str_y"])]
-                metrics["edit_graph"], metrics["nb_nodes_and_edges"] = [ei[0] for ei in edit_and_num_items], [ei[1] for ei in edit_and_num_items]
-        return metrics
-
-    def get(self, name):
-        return self.epoch_metrics[name]
-
-
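For reference, the manager above was driven in three steps per batch: compute per-sample metrics, accumulate them, then aggregate for display. A minimal sketch, assuming a hypothetical dataset name with no layout tokens and toy prediction/ground-truth pairs:

    manager = MetricManager(metric_names=["cer", "wer"], dataset_name="custom_lines")
    batch_values = {
        "nb_samples": 2,
        "str_y": ["the ground truth", "another line"],  # references
        "str_x": ["the ground trut", "another lines"],  # predictions
    }
    manager.update_metrics(manager.compute_metrics(batch_values, ["cer", "wer"]))
    print(manager.get_display_values(output=True))  # e.g. {'cer': ..., 'wer': ..., 'nb_samples': 2, ...}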
-def keep_only_tokens(str, tokens):
-    """
-    Remove all but layout tokens from string
-    """
-    return re.sub('([^' + tokens + '])', '', str)
-
-
-def keep_all_but_tokens(str, tokens):
-    """
-    Remove all layout tokens from string
-    """
-    return re.sub('([' + tokens + '])', '', str)
-
-
-def edit_cer_from_string(gt, pred, layout_tokens=None):
-    """
-    Format and compute edit distance between two strings at character level
-    """
-    gt = format_string_for_cer(gt, layout_tokens)
-    pred = format_string_for_cer(pred, layout_tokens)
-    return editdistance.eval(gt, pred)
-
-
-def nb_chars_cer_from_string(gt, layout_tokens=None):
-    """
-    Compute length after formatting of ground truth string
-    """
-    return len(format_string_for_cer(gt, layout_tokens))
-
-
-def edit_wer_from_string(gt, pred, layout_tokens=None):
-    """
-    Format and compute edit distance between two strings at word level
-    """
-    split_gt = format_string_for_wer(gt, layout_tokens)
-    split_pred = format_string_for_wer(pred, layout_tokens)
-    return edit_wer_from_formatted_split_text(split_gt, split_pred)
-
-
-def format_string_for_wer(str, layout_tokens):
-    """
-    Format string for WER computation: remove layout tokens, treat punctuation marks as words, replace line breaks with spaces
-    """
-    str = re.sub('([\[\]{}/\\()\"\'&+*=<>?.;:,!\-—_€#%°])', r' \1 ', str)  # punctuation processed as word
-    if layout_tokens is not None:
-        str = keep_all_but_tokens(str, layout_tokens)  # remove layout tokens from metric
-    str = re.sub('([ \n])+', " ", str).strip()  # keep only one space character
-    return str.split(" ")
-
-
-def format_string_for_cer(str, layout_tokens):
-    """
-    Format string for CER computation: remove layout tokens and extra spaces
-    """
-    if layout_tokens is not None:
-        str = keep_all_but_tokens(str, layout_tokens)  # remove layout tokens from metric
-    str = re.sub('([\n])+', "\n", str)  # remove consecutive line breaks
-    str = re.sub('([ ])+', " ", str).strip()  # remove consecutive spaces
-    return str
-
-
-def edit_wer_from_formatted_split_text(gt, pred):
-    """
-    Compute edit distance at word level from formatted string as list
-    """
-    return editdistance.eval(gt, pred)
-
-
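Taken together, the corpus-level CER reported above is the total number of character edits divided by the total ground-truth length, not an average of per-line rates. A small worked example:

    gts = ["le chat", "un chien"]
    preds = ["le chot", "un chien"]
    edits = sum(edit_cer_from_string(gt, pred) for gt, pred in zip(gts, preds))
    chars = sum(nb_chars_cer_from_string(gt) for gt in gts)
    print(edits / chars)  # 1 edit over 15 characters ~ 0.0667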
-def extract_by_tokens(input_str, begin_token, end_token, associated_score=None, order_by_score=False):
-    """
-    Extract the list of text regions delimited by begin and end tokens,
-    optionally ordered by confidence score
-    """
-    if order_by_score:
-        assert associated_score is not None
-    res = list()
-    for match in re.finditer("{}[^{}]*{}".format(begin_token, end_token, end_token), input_str):
-        begin, end = match.regs[0]
-        if order_by_score:
-            res.append({
-                "confidence": np.mean([associated_score[begin], associated_score[end-1]]),
-                "content": input_str[begin+1:end-1]
-            })
-        else:
-            res.append(input_str[begin+1:end-1])
-    if order_by_score:
-        res = sorted(res, key=lambda x: x["confidence"], reverse=True)
-        res = [r["content"] for r in res]
-    return res
-
-
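An illustrative call, using the SIMARA begin/end pair ⓘ/Ⓘ:

    regions = extract_by_tokens("ⓘtitleⒾ ⓘsubtitleⒾ", "ⓘ", "Ⓘ")
    print(regions)  # ['title', 'subtitle']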
-def compute_layout_precision_per_threshold(gt, pred, score, begin_token, end_token, layout_tokens, return_weight=True):
-    """
-    Compute average precision of a given class for CER thresholds from 5% to 50% with a step of 5%
-    """
-    pred_list = extract_by_tokens(pred, begin_token, end_token, associated_score=score, order_by_score=True)
-    gt_list = extract_by_tokens(gt, begin_token, end_token)
-    pred_list = [keep_all_but_tokens(p, layout_tokens) for p in pred_list]
-    gt_list = [keep_all_but_tokens(gt, layout_tokens) for gt in gt_list]
-    precision_per_threshold = [compute_layout_AP_for_given_threshold(gt_list, pred_list, threshold/100) for threshold in range(5, 51, 5)]
-    if return_weight:
-        return precision_per_threshold, len(gt_list)
-    return precision_per_threshold
-
-
-def compute_layout_AP_for_given_threshold(gt_list, pred_list, threshold):
-    """
-    Compute average precision of a given class for a given CER threshold
-    """
-    remaining_gt_list = gt_list.copy()
-    num_true = len(gt_list)
-    correct = np.zeros((len(pred_list)), dtype=bool)
-    for i, pred in enumerate(pred_list):
-        if len(remaining_gt_list) == 0:
-            break
-        cer_with_gt = [edit_cer_from_string(gt, pred)/nb_chars_cer_from_string(gt) for gt in remaining_gt_list]
-        cer, ind = np.min(cer_with_gt), np.argmin(cer_with_gt)
-        if cer <= threshold:
-            correct[i] = True
-            del remaining_gt_list[ind]
-    precision = np.cumsum(correct, dtype=int) / np.arange(1, len(pred_list)+1)
-    recall = np.cumsum(correct, dtype=int) / num_true
-    max_precision_from_recall = np.maximum.accumulate(precision[::-1])[::-1]
-    recall_diff = (recall - np.concatenate([np.array([0, ]), recall[:-1]]))
-    P = np.sum(recall_diff * max_precision_from_recall)
-    return P
-
-
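This is the usual interpolated average precision; a worked check with illustrative numbers (three predictions sorted by confidence against two ground-truth regions):

    import numpy as np
    correct = np.array([True, False, True])
    precision = np.cumsum(correct) / np.arange(1, 4)       # [1.0, 0.5, 0.667]
    recall = np.cumsum(correct) / 2                        # [0.5, 0.5, 1.0]
    interp = np.maximum.accumulate(precision[::-1])[::-1]  # [1.0, 0.667, 0.667]
    ap = np.sum((recall - np.concatenate([[0], recall[:-1]])) * interp)
    print(ap)  # 5/6 ~ 0.833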
-def compute_layout_mAP_per_class(gt, pred, score, tokens):
-    """
-    Compute the mAP_cer for each class for a given sample
-    """
-    layout_tokens = "".join(list(tokens.keys()))
-    AP_per_class = dict()
-    for token in tokens.keys():
-        if token in gt:
-            AP_per_class[token] = compute_layout_precision_per_threshold(gt, pred, score, token, tokens[token], layout_tokens=layout_tokens)
-    return AP_per_class
-
-
-def compute_global_mAP(list_AP_per_class):
-    """
-    Compute the global mAP_cer for several samples
-    """
-    weights_per_doc = list()
-    mAP_per_doc = list()
-    for doc_AP_per_class in list_AP_per_class:
-        APs = np.array([np.mean(doc_AP_per_class[key][0]) for key in doc_AP_per_class.keys()])
-        weights = np.array([doc_AP_per_class[key][1] for key in doc_AP_per_class.keys()])
-        if np.sum(weights) == 0:
-            mAP_per_doc.append(0)
-        else:
-            mAP_per_doc.append(np.average(APs, weights=weights))
-        weights_per_doc.append(np.sum(weights))
-    if np.sum(weights_per_doc) == 0:
-        return 0
-    return np.average(mAP_per_doc, weights=weights_per_doc)
-
-
-def compute_global_mAP_per_class(list_AP_per_class):
-    """
-    Compute the mAP_cer per class for several samples
-    """
-    mAP_per_class = dict()
-    for doc_AP_per_class in list_AP_per_class:
-        for key in doc_AP_per_class.keys():
-            if key not in mAP_per_class:
-                mAP_per_class[key] = {
-                    "AP": list(),
-                    "weights": list()
-                }
-            mAP_per_class[key]["AP"].append(np.mean(doc_AP_per_class[key][0]))
-            mAP_per_class[key]["weights"].append(doc_AP_per_class[key][1])
-    for key in mAP_per_class.keys():
-        mAP_per_class[key] = np.average(mAP_per_class[key]["AP"], weights=mAP_per_class[key]["weights"])
-    return mAP_per_class
-
-
-def compute_global_precision_per_class_per_threshold(list_AP_per_class):
-    """
-    Compute the mAP_cer per class and per threshold for several samples
-    """
-    mAP_per_class = dict()
-    for doc_AP_per_class in list_AP_per_class:
-        for key in doc_AP_per_class.keys():
-            if key not in mAP_per_class:
-                mAP_per_class[key] = dict()
-                for threshold in range(5, 51, 5):
-                    mAP_per_class[key][threshold] = {
-                        "precision": list(),
-                        "weights": list()
-                    }
-            for i, threshold in enumerate(range(5, 51, 5)):
-                mAP_per_class[key][threshold]["precision"].append(np.mean(doc_AP_per_class[key][0][i]))
-                mAP_per_class[key][threshold]["weights"].append(doc_AP_per_class[key][1])
-    for key_class in mAP_per_class.keys():
-        for threshold in mAP_per_class[key_class]:
-            mAP_per_class[key_class][threshold] = np.average(mAP_per_class[key_class][threshold]["precision"], weights=mAP_per_class[key_class][threshold]["weights"])
-    return mAP_per_class
-
-
-def str_to_graph_read(str):
-    """
-    Compute graph from string of layout tokens for the READ 2016 dataset at single-page and double-page levels
-    """
-    begin_layout_tokens = "".join(list(READ_MATCHING_TOKENS.keys()))
-    layout_token_sequence = keep_only_tokens(str, begin_layout_tokens)
-    g = nx.DiGraph()
-    g.add_node("D", type="document", level=4, page=0)
-    num = {
-        "â“Ÿ": 0,
-        "ⓐ": 0,
-        "â“‘": 0,
-        "ⓝ": 0,
-        "â“¢": 0
-    }
-    previous_top_level_node = None
-    previous_middle_level_node = None
-    previous_low_level_node = None
-    for ind, c in enumerate(layout_token_sequence):
-        num[c] += 1
-        if c == "ⓟ":
-            node_name = "P_{}".format(num[c])
-            g.add_node(node_name, type="page", level=3, page=num["â“Ÿ"])
-            g.add_edge("D", node_name)
-            if previous_top_level_node:
-                g.add_edge(previous_top_level_node, node_name)
-            previous_top_level_node = node_name
-            previous_middle_level_node = None
-            previous_low_level_node = None
-        if c in "ⓝⓢ":
-            node_name = "{}_{}".format("N" if c == "ⓝ" else "S", num[c])
-            g.add_node(node_name, type="number" if c == "ⓝ" else "section", level=2, page=num["ⓟ"])
-            g.add_edge(previous_top_level_node, node_name)
-            if previous_middle_level_node:
-                g.add_edge(previous_middle_level_node, node_name)
-            previous_middle_level_node = node_name
-            previous_low_level_node = None
-        if c in "ⓐⓑ":
-            node_name = "{}_{}".format("A" if c == "ⓐ" else "B", num[c])
-            g.add_node(node_name, type="annotation" if c == "ⓐ" else "body", level=1, page=num["ⓟ"])
-            g.add_edge(previous_middle_level_node, node_name)
-            if previous_low_level_node:
-                g.add_edge(previous_low_level_node, node_name)
-            previous_low_level_node = node_name
-    return g
-
-
-def str_to_graph_rimes(str):
-    """
-    Compute graph from string of layout tokens for the RIMES dataset at page level
-    """
-    begin_layout_tokens = "".join(list(RIMES_MATCHING_TOKENS.keys()))
-    layout_token_sequence = keep_only_tokens(str, begin_layout_tokens)
-    g = nx.DiGraph()
-    g.add_node("D", type="document", level=2, page=0)
-    token_name_dict = {
-        "â“‘": "B",
-        "â“ž": "O",
-        "â“¡": "R",
-        "â“¢": "S",
-        "ⓦ": "W",
-        "ⓨ": "Y",
-        "â“Ÿ": "P"
-    }
-    num = dict()
-    previous_node = None
-    for token in begin_layout_tokens:
-        num[token] = 0
-    for ind, c in enumerate(layout_token_sequence):
-        num[c] += 1
-        node_name = "{}_{}".format(token_name_dict[c], num[c])
-        g.add_node(node_name, type=token_name_dict[c], level=1, page=0)
-        g.add_edge("D", node_name)
-        if previous_node:
-            g.add_edge(previous_node, node_name)
-        previous_node = node_name
-    return g
-
-
-def str_to_graph_simara(str):
-    """
-    Compute graph from string of layout tokens for the SIMARA dataset at page level
-    """
-    begin_layout_tokens = "".join(list(SIMARA_MATCHING_TOKENS.keys()))
-    layout_token_sequence = keep_only_tokens(str, begin_layout_tokens)
-    g = nx.DiGraph()
-    g.add_node("D", type="document", level=2, page=0)
-    token_name_dict = {
-        "ⓘ": "I",
-        "â““": "D",
-        "â“¢": "S",
-        "â“’": "C",
-        "â“Ÿ": "P",
-        "ⓐ": "A"
-    }
-    num = dict()
-    previous_node = None
-    for token in begin_layout_tokens:
-        num[token] = 0
-    for ind, c in enumerate(layout_token_sequence):
-        num[c] += 1
-        node_name = "{}_{}".format(token_name_dict[c], num[c])
-        g.add_node(node_name, type=token_name_dict[c], level=1, page=0)
-        g.add_edge("D", node_name)
-        if previous_node:
-            g.add_edge(previous_node, node_name)
-        previous_node = node_name
-    return g
-
-
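A quick sketch of the resulting structure, using two of the SIMARA tokens above:

    g = str_to_graph_simara("ⓘtitleⒾ ⓓ1789Ⓓ")
    print(g.number_of_nodes(), g.number_of_edges())  # 3 3: D->I_1, D->D_1, I_1->D_1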
-def graph_edit_distance_by_page_read(g1, g2):
-    """
-    Compute graph edit distance page by page for the READ 2016 dataset
-    """
-    num_pages_g1 = len([n for n in g1.nodes().items() if n[1]["level"] == 3])
-    num_pages_g2 = len([n for n in g2.nodes().items() if n[1]["level"] == 3])
-    page_graphs_1 = [g1.subgraph([n[0] for n in g1.nodes().items() if n[1]["page"] == num_page]) for num_page in range(1, num_pages_g1+1)]
-    page_graphs_2 = [g2.subgraph([n[0] for n in g2.nodes().items() if n[1]["page"] == num_page]) for num_page in range(1, num_pages_g2+1)]
-    edit = 0
-    for i in range(max(len(page_graphs_1), len(page_graphs_2))):
-        page_1 = page_graphs_1[i] if i < len(page_graphs_1) else nx.DiGraph()
-        page_2 = page_graphs_2[i] if i < len(page_graphs_2) else nx.DiGraph()
-        edit += graph_edit_distance(page_1, page_2)
-    return edit
-
-
-def graph_edit_distance(g1, g2):
-    """
-    Compute graph edit distance between two graphs
-    """
-    # optimize_graph_edit_distance yields successively better approximations;
-    # iterate the generator to completion and keep the last (best) value
-    for v in nx.optimize_graph_edit_distance(g1, g2,
-                                             node_ins_cost=lambda node: 1,
-                                             node_del_cost=lambda node: 1,
-                                             node_subst_cost=lambda node1, node2: 0 if node1["type"] == node2["type"] else 1,
-                                             edge_ins_cost=lambda edge: 1,
-                                             edge_del_cost=lambda edge: 1,
-                                             edge_subst_cost=lambda edge1, edge2: 0 if edge1 == edge2 else 1
-                                             ):
-        new_edit = v
-    return new_edit
-
-
-def edit_and_num_items_for_ged_from_str_read(str_gt, str_pred):
-    """
-    Compute graph edit distance and num nodes/edges for normalized graph edit distance
-    For the READ 2016 dataset
-    """
-    g_gt = str_to_graph_read(str_gt)
-    g_pred = str_to_graph_read(str_pred)
-    return graph_edit_distance_by_page_read(g_gt, g_pred), g_gt.number_of_nodes() + g_gt.number_of_edges()
-
-
-def edit_and_num_items_for_ged_from_str_rimes(str_gt, str_pred):
-    """
-    Compute graph edit distance and num nodes/edges for normalized graph edit distance
-    For the RIMES dataset
-    """
-    g_gt = str_to_graph_rimes(str_gt)
-    g_pred = str_to_graph_rimes(str_pred)
-    return graph_edit_distance(g_gt, g_pred), g_gt.number_of_nodes() + g_gt.number_of_edges()
-
-
-def edit_and_num_items_for_ged_from_str_simara(str_gt, str_pred):
-    """
-    Compute graph edit distance and num nodes/edges for normalized graph edit distance
-    For the SIMARA dataset
-    """
-    g_gt = str_to_graph_simara(str_gt)
-    g_pred = str_to_graph_simara(str_pred)
-    return graph_edit_distance(g_gt, g_pred), g_gt.number_of_nodes() + g_gt.number_of_edges()
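The LOER reported by MetricManager is then the summed graph edit distance normalized by the size (nodes + edges) of the ground-truth graph. A one-sample sketch with toy SIMARA strings:

    edit, size = edit_and_num_items_for_ged_from_str_simara("ⓘaⒾ ⓓbⒹ", "ⓘaⒾ")
    print(edit / size)  # normalized layout error for this sample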
diff --git a/basic/scheduler.py b/basic/scheduler.py
deleted file mode 100644
index 6c875c1da2f57ab0306635cbc77309c2b83af116..0000000000000000000000000000000000000000
--- a/basic/scheduler.py
+++ /dev/null
@@ -1,51 +0,0 @@
-
-from torch.nn import Dropout, Dropout2d
-import numpy as np
-
-
-class DropoutScheduler:
-
-    def __init__(self, models, function, T=1e5):
-        """
-        T: number of gradient updates to converge
-        """
-
-        self.teta_list = list()
-        self.init_teta_list(models)
-        self.function = function
-        self.T = T
-        self.step_num = 0
-
-    def step(self, num=1):
-        self.step_num += num
-
-    def init_teta_list(self, models):
-        for model_name in models.keys():
-            self.init_teta_list_module(models[model_name])
-
-    def init_teta_list_module(self, module):
-        for child in module.children():
-            if isinstance(child, Dropout) or isinstance(child, Dropout2d):
-                self.teta_list.append([child, child.p])
-            else:
-                self.init_teta_list_module(child)
-
-    def update_dropout_rate(self):
-        for (module, p) in self.teta_list:
-            module.p = self.function(p, self.step_num, self.T)
-
-
-def exponential_dropout_scheduler(dropout_rate, step, max_step):
-    return dropout_rate * (1 - np.exp(-10 * step / max_step))
-
-
-def exponential_scheduler(init_value, end_value, step, max_step):
-    step = min(step, max_step-1)
-    return init_value - (init_value - end_value) * (1 - np.exp(-10*step/max_step))
-
-
-def linear_scheduler(init_value, end_value, step, max_step):
-    return init_value + step * (end_value - init_value) / max_step
\ No newline at end of file
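A minimal usage sketch of the scheduler on a toy model dict (not the project's trainer): each dropout probability ramps from 0 towards its configured value as gradient updates accumulate.

    import torch.nn as nn
    models = {"encoder": nn.Sequential(nn.Linear(8, 8), nn.Dropout(0.5))}
    scheduler = DropoutScheduler(models, exponential_dropout_scheduler, T=1000)
    for _ in range(100):
        scheduler.step(1)
        scheduler.update_dropout_rate()
    print(models["encoder"][1].p)  # 0.5 * (1 - exp(-1)) ~ 0.316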
diff --git a/basic/transforms.py b/basic/transforms.py
deleted file mode 100644
index 18c8084dafe668025ea7ae58ce99384b58a672a1..0000000000000000000000000000000000000000
--- a/basic/transforms.py
+++ /dev/null
@@ -1,438 +0,0 @@
-
-import numpy as np
-from numpy import random
-from PIL import Image, ImageOps
-from cv2 import erode, dilate, normalize
-import cv2
-import math
-from basic.utils import randint, rand_uniform, rand
-from torchvision.transforms import RandomPerspective, RandomCrop, ColorJitter, GaussianBlur, RandomRotation
-from torchvision.transforms.functional import InterpolationMode
-
-"""
-Each transform class defined here takes as input a PIL Image and returns the modified PIL Image
-"""
-
-
-class SignFlipping:
-    """
-    Color inversion
-    """
-
-    def __init__(self):
-        pass
-
-    def __call__(self, x):
-        return ImageOps.invert(x)
-
-
-class DPIAdjusting:
-    """
-    Resolution modification
-    """
-
-    def __init__(self, factor, preserve_ratio):
-        # preserve_ratio is accepted for config compatibility; both dimensions
-        # are scaled by the same factor, so the ratio is always preserved
-        self.factor = factor
-
-    def __call__(self, x):
-        w, h = x.size
-        return x.resize((int(np.ceil(w * self.factor)), int(np.ceil(h * self.factor))), Image.BILINEAR)
-
-
-class Dilation:
-    """
-    OCR: stroke width increasing
-    """
-
-    def __init__(self, kernel, iterations):
-        self.kernel = np.ones(kernel, np.uint8)
-        self.iterations = iterations
-
-    def __call__(self, x):
-        return Image.fromarray(dilate(np.array(x), self.kernel, iterations=self.iterations))
-
-
-class Erosion:
-    """
-    OCR: stroke width decreasing
-    """
-
-    def __init__(self, kernel, iterations):
-        self.kernel = np.ones(kernel, np.uint8)
-        self.iterations = iterations
-
-    def __call__(self, x):
-        return Image.fromarray(erode(np.array(x), self.kernel, iterations=self.iterations))
-
-
-class GaussianNoise:
-    """
-    Add Gaussian Noise
-    """
-
-    def __init__(self, std):
-        self.std = std
-
-    def __call__(self, x):
-        x_np = np.array(x)
-        mean, std = np.mean(x_np), np.std(x_np)
-        std = math.copysign(max(abs(std), 0.000001), std)
-        min_, max_ = np.min(x_np,), np.max(x_np)
-        normal_noise = np.random.randn(*x_np.shape)
-        if len(x_np.shape) == 3 and x_np.shape[2] == 3 and np.all(x_np[:, :, 0] == x_np[:, :, 1]) and np.all(x_np[:, :, 0] == x_np[:, :, 2]):
-            normal_noise[:, :, 1] = normal_noise[:, :, 2] = normal_noise[:, :, 0]
-        x_np = ((x_np-mean)/std + normal_noise*self.std) * std + mean
-        x_np = normalize(x_np, x_np, max_, min_, cv2.NORM_MINMAX)
-
-        return Image.fromarray(x_np.astype(np.uint8))
-
-
-class Sharpen:
-    """
-    Sharpen the image
-    """
-
-    def __init__(self, alpha, strength):
-        self.alpha = alpha
-        self.strength = strength
-
-    def __call__(self, x):
-        x_np = np.array(x)
-        id_matrix = np.array([[0, 0, 0],
-                              [0, 1, 0],
-                              [0, 0, 0]]
-                             )
-        effect_matrix = np.array([[1, 1, 1],
-                                  [1, -(8+self.strength), 1],
-                                  [1, 1, 1]]
-                                 )
-        kernel = (1 - self.alpha) * id_matrix - self.alpha * effect_matrix
-        kernel = np.expand_dims(kernel, axis=2)
-        kernel = np.concatenate([kernel, kernel, kernel], axis=2)
-        sharpened = cv2.filter2D(x_np, -1, kernel=kernel[:, :, 0])
-        return Image.fromarray(sharpened.astype(np.uint8))
-
-
-class ZoomRatio:
-    """
-        Crop by ratio
-        Preserve dimensions if keep_dim = True (= zoom)
-    """
-
-    def __init__(self, ratio_h, ratio_w, keep_dim=True):
-        self.ratio_w = ratio_w
-        self.ratio_h = ratio_h
-        self.keep_dim = keep_dim
-
-    def __call__(self, x):
-        w, h = x.size
-        x = RandomCrop((int(h * self.ratio_h), int(w * self.ratio_w)))(x)
-        if self.keep_dim:
-            x = x.resize((w, h), Image.BILINEAR)
-        return x
-
-
-class ElasticDistortion:
-
-    def __init__(self, kernel_size=(7, 7), sigma=5, alpha=1):
-
-        self.kernel_size = kernel_size
-        self.sigma = sigma
-        self.alpha = alpha
-
-    def __call__(self, x):
-        x_np = np.array(x)
-
-        h, w = x_np.shape[:2]
-
-        dx = np.random.uniform(-1, 1, (h, w))
-        dy = np.random.uniform(-1, 1, (h, w))
-
-        x_gauss = cv2.GaussianBlur(dx, self.kernel_size, self.sigma)
-        y_gauss = cv2.GaussianBlur(dy, self.kernel_size, self.sigma)
-
-        n = np.sqrt(x_gauss**2 + y_gauss**2)
-
-        nd_x = self.alpha * x_gauss / n
-        nd_y = self.alpha * y_gauss / n
-
-        ind_y, ind_x = np.indices((h, w), dtype=np.float32)
-
-        map_x = nd_x + ind_x
-        map_x = map_x.reshape(h, w).astype(np.float32)
-        map_y = nd_y + ind_y
-        map_y = map_y.reshape(h, w).astype(np.float32)
-
-        dst = cv2.remap(x_np, map_x, map_y, cv2.INTER_LINEAR)
-        return Image.fromarray(dst.astype(np.uint8))
-
-
-class Tightening:
-    """
-    Reduce interline spacing
-    """
-
-    def __init__(self, color=255, remove_proba=0.75):
-        self.color = color
-        self.remove_proba = remove_proba
-
-    def __call__(self, x):
-        x_np = np.array(x)
-        interline_indices = [np.all(line == self.color) for line in x_np]
-        indices_to_remove = np.logical_and(np.random.choice([True, False], size=len(x_np), replace=True, p=[self.remove_proba, 1-self.remove_proba]), interline_indices)
-        new_x = x_np[np.logical_not(indices_to_remove)]
-        return Image.fromarray(new_x.astype(np.uint8))
-
-
-def get_list_augmenters(img, aug_configs, fill_value):
-    """
-    Randomly select a list of data augmentation techniques to use based on aug_configs
-    """
-    augmenters = list()
-    for aug_config in aug_configs:
-        if rand() > aug_config["proba"]:
-            continue
-        if aug_config["type"] == "dpi":
-            valid_factor = False
-            while not valid_factor:
-                factor = rand_uniform(aug_config["min_factor"], aug_config["max_factor"])
-                valid_factor = not (("max_width" in aug_config and factor*img.size[0] > aug_config["max_width"]) or \
-                               ("max_height" in aug_config and factor * img.size[1] > aug_config["max_height"]) or \
-                               ("min_width" in aug_config and factor*img.size[0] < aug_config["min_width"]) or \
-                               ("min_height" in aug_config and factor * img.size[1] < aug_config["min_height"]))
-            augmenters.append(DPIAdjusting(factor, preserve_ratio=aug_config["preserve_ratio"]))
-
-        elif aug_config["type"] == "zoom_ratio":
-            ratio_h = rand_uniform(aug_config["min_ratio_h"], aug_config["max_ratio_h"])
-            ratio_w = rand_uniform(aug_config["min_ratio_w"], aug_config["max_ratio_w"])
-            augmenters.append(ZoomRatio(ratio_h=ratio_h, ratio_w=ratio_w, keep_dim=aug_config["keep_dim"]))
-
-        elif aug_config["type"] == "perspective":
-            scale = rand_uniform(aug_config["min_factor"], aug_config["max_factor"])
-            augmenters.append(RandomPerspective(distortion_scale=scale, p=1, interpolation=InterpolationMode.BILINEAR, fill=fill_value))
-
-        elif aug_config["type"] == "elastic_distortion":
-            kernel_size = randint(aug_config["min_kernel_size"], aug_config["max_kernel_size"]) // 2 * 2 + 1
-            sigma = rand_uniform(aug_config["min_sigma"], aug_config["max_sigma"])
-            alpha = rand_uniform(aug_config["min_alpha"], aug_config["max_alpha"])
-            augmenters.append(ElasticDistortion(kernel_size=(kernel_size, kernel_size), sigma=sigma, alpha=alpha))
-
-        elif aug_config["type"] == "dilation_erosion":
-            kernel_h = randint(aug_config["min_kernel"], aug_config["max_kernel"] + 1)
-            kernel_w = randint(aug_config["min_kernel"], aug_config["max_kernel"] + 1)
-            if randint(0, 2) == 0:
-                augmenters.append(Erosion((kernel_w, kernel_h), aug_config["iterations"]))
-            else:
-                augmenters.append(Dilation((kernel_w, kernel_h), aug_config["iterations"]))
-
-        elif aug_config["type"] == "color_jittering":
-            augmenters.append(ColorJitter(contrast=aug_config["factor_contrast"],
-                              brightness=aug_config["factor_brightness"],
-                              saturation=aug_config["factor_saturation"],
-                              hue=aug_config["factor_hue"],
-                              ))
-
-        elif aug_config["type"] == "gaussian_blur":
-            max_kernel_h = min(aug_config["max_kernel"], img.size[1])
-            max_kernel_w = min(aug_config["max_kernel"], img.size[0])
-            kernel_h = randint(aug_config["min_kernel"], max_kernel_h + 1) // 2 * 2 + 1
-            kernel_w = randint(aug_config["min_kernel"], max_kernel_w + 1) // 2 * 2 + 1
-            sigma = rand_uniform(aug_config["min_sigma"], aug_config["max_sigma"])
-            augmenters.append(GaussianBlur(kernel_size=(kernel_w, kernel_h), sigma=sigma))
-
-        elif aug_config["type"] == "gaussian_noise":
-            augmenters.append(GaussianNoise(std=aug_config["std"]))
-
-        elif aug_config["type"] == "sharpen":
-            alpha = rand_uniform(aug_config["min_alpha"], aug_config["max_alpha"])
-            strength = rand_uniform(aug_config["min_strength"], aug_config["max_strength"])
-            augmenters.append(Sharpen(alpha=alpha, strength=strength))
-
-        else:
-            print("Error - unknown augmentor: {}".format(aug_config["type"]))
-            exit(-1)
-
-    return augmenters
-
-
-def apply_data_augmentation(img, da_config):
-    """
-    Apply data augmentation strategy on input image
-    """
-    applied_da = list()
-    if da_config["proba"] != 1 and rand() > da_config["proba"]:
-        return img, applied_da
-
-    # Convert to PIL Image
-    img = img[:, :, 0] if img.shape[2] == 1 else img
-    img = Image.fromarray(img)
-
-    fill_value = da_config["fill_value"] if "fill_value" in da_config else 255
-    augmenters = get_list_augmenters(img, da_config["augmentations"], fill_value=fill_value)
-    if da_config["order"] == "random":
-        random.shuffle(augmenters)
-
-    for augmenter in augmenters:
-        img = augmenter(img)
-        applied_da.append(type(augmenter).__name__)
-
-    # convert to numpy array
-    img = np.array(img)
-    img = np.expand_dims(img, axis=2) if len(img.shape) == 2 else img
-    return img, applied_da
-
-
-def apply_transform(img, transform):
-    """
-    Apply data augmentation technique on input image
-    """
-    img = img[:, :, 0] if img.shape[2] == 1 else img
-    img = Image.fromarray(img)
-    img = transform(img)
-    img = np.array(img)
-    return np.expand_dims(img, axis=2) if len(img.shape) == 2 else img
-
-
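For example, a single deterministic transform applied to a HWC numpy image:

    import numpy as np
    out = apply_transform(np.zeros((64, 128, 1), dtype=np.uint8), DPIAdjusting(0.5, preserve_ratio=True))
    print(out.shape)  # (32, 64, 1)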
-def line_aug_config(proba_use_da, p):
-    return {
-        "order": "random",
-        "proba": proba_use_da,
-        "augmentations": [
-            {
-                "type": "dpi",
-                "proba": p,
-                "min_factor": 0.5,
-                "max_factor": 1.5,
-                "preserve_ratio": True,
-            },
-            {
-                "type": "perspective",
-                "proba": p,
-                "min_factor": 0,
-                "max_factor": 0.4,
-            },
-            {
-                "type": "elastic_distortion",
-                "proba": p,
-                "min_alpha": 0.5,
-                "max_alpha": 1,
-                "min_sigma": 1,
-                "max_sigma": 10,
-                "min_kernel_size": 3,
-                "max_kernel_size": 9,
-            },
-            {
-                "type": "dilation_erosion",
-                "proba": p,
-                "min_kernel": 1,
-                "max_kernel": 3,
-                "iterations": 1,
-            },
-            {
-                "type": "color_jittering",
-                "proba": p,
-                "factor_hue": 0.2,
-                "factor_brightness": 0.4,
-                "factor_contrast": 0.4,
-                "factor_saturation": 0.4,
-            },
-            {
-                "type": "gaussian_blur",
-                "proba": p,
-                "min_kernel": 3,
-                "max_kernel": 5,
-                "min_sigma": 3,
-                "max_sigma": 5,
-            },
-            {
-                "type": "gaussian_noise",
-                "proba": p,
-                "std": 0.5,
-            },
-            {
-                "type": "sharpen",
-                "proba": p,
-                "min_alpha": 0,
-                "max_alpha": 1,
-                "min_strength": 0,
-                "max_strength": 1,
-            },
-            {
-                "type": "zoom_ratio",
-                "proba": p,
-                "min_ratio_h": 0.8,
-                "max_ratio_h": 1,
-                "min_ratio_w": 0.99,
-                "max_ratio_w": 1,
-                "keep_dim": True
-            },
-        ]
-    }
-
-
-def aug_config(proba_use_da, p):
-    return {
-        "order": "random",
-        "proba": proba_use_da,
-        "augmentations": [
-            {
-                "type": "dpi",
-                "proba": p,
-                "min_factor": 0.75,
-                "max_factor": 1,
-                "preserve_ratio": True,
-            },
-            {
-                "type": "perspective",
-                "proba": p,
-                "min_factor": 0,
-                "max_factor": 0.4,
-            },
-            {
-                "type": "elastic_distortion",
-                "proba": p,
-                "min_alpha": 0.5,
-                "max_alpha": 1,
-                "min_sigma": 1,
-                "max_sigma": 10,
-                "min_kernel_size": 3,
-                "max_kernel_size": 9,
-            },
-            {
-                "type": "dilation_erosion",
-                "proba": p,
-                "min_kernel": 1,
-                "max_kernel": 3,
-                "iterations": 1,
-            },
-            {
-                "type": "color_jittering",
-                "proba": p,
-                "factor_hue": 0.2,
-                "factor_brightness": 0.4,
-                "factor_contrast": 0.4,
-                "factor_saturation": 0.4,
-            },
-            {
-                "type": "gaussian_blur",
-                "proba": p,
-                "min_kernel": 3,
-                "max_kernel": 5,
-                "min_sigma": 3,
-                "max_sigma": 5,
-            },
-            {
-                "type": "gaussian_noise",
-                "proba": p,
-                "std": 0.5,
-            },
-            {
-                "type": "sharpen",
-                "proba": p,
-                "min_alpha": 0,
-                "max_alpha": 1,
-                "min_strength": 0,
-                "max_strength": 1,
-            },
-        ]
-    }
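Putting the two configuration helpers to work, a hedged sketch augmenting one random page-like image with aug_config (90% chance of augmenting at all, each augmenter drawn with probability 0.2):

    import numpy as np
    img = np.random.randint(0, 255, (256, 192, 3), dtype=np.uint8)
    augmented, applied = apply_data_augmentation(img, aug_config(proba_use_da=0.9, p=0.2))
    print(applied)  # e.g. ['GaussianBlur', 'DPIAdjusting'] (order is shuffled)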
diff --git a/basic/utils.py b/basic/utils.py
deleted file mode 100644
index ac50e57ab1136a2f2e1d209c39271406316d0479..0000000000000000000000000000000000000000
--- a/basic/utils.py
+++ /dev/null
@@ -1,179 +0,0 @@
-
-
-import numpy as np
-import torch
-from torch.distributions.uniform import Uniform
-import cv2
-
-
-def randint(low, high):
-    """
-    call torch.randint to preserve randomness across dataloader workers
-    """
-    return int(torch.randint(low, high, (1, )))
-
-
-def rand():
-    """
-    call torch.rand to preserve randomness across dataloader workers
-    """
-    return float(torch.rand((1, )))
-
-
-def rand_uniform(low, high):
-    """
-    call torch uniform to preserve randomness across dataloader workers
-    """
-    return float(Uniform(low, high).sample())
-
-
-def pad_sequences_1D(data, padding_value):
-    """
-    Pad data with padding_value to get same length
-    """
-    x_lengths = [len(x) for x in data]
-    longest_x = max(x_lengths)
-    padded_data = np.ones((len(data), longest_x)).astype(np.int32) * padding_value
-    for i, x_len in enumerate(x_lengths):
-        padded_data[i, :x_len] = data[i][:x_len]
-    return padded_data
-
-
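For instance, two token sequences of different lengths padded with index 0:

    print(pad_sequences_1D([[1, 2, 3], [4]], padding_value=0))
    # [[1 2 3]
    #  [4 0 0]]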
-def resize_max(img, max_width=None, max_height=None):
-    if max_width is not None and img.shape[1] > max_width:
-        ratio = max_width / img.shape[1]
-        new_h = int(np.floor(ratio * img.shape[0]))
-        new_w = int(np.floor(ratio * img.shape[1]))
-        img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
-    if max_height is not None and img.shape[0] > max_height:
-        ratio = max_height / img.shape[0]
-        new_h = int(np.floor(ratio * img.shape[0]))
-        new_w = int(np.floor(ratio * img.shape[1]))
-        img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
-    return img
-
-
-def pad_images(data, padding_value, padding_mode="br"):
-    """
-    data: list of numpy array
-    mode: "br"/"tl"/"random" (bottom-right, top-left, random)
-    """
-    x_lengths = [x.shape[0] for x in data]
-    y_lengths = [x.shape[1] for x in data]
-    longest_x = max(x_lengths)
-    longest_y = max(y_lengths)
-    padded_data = np.ones((len(data), longest_x, longest_y, data[0].shape[2])) * padding_value
-    for i, xy_len in enumerate(zip(x_lengths, y_lengths)):
-        x_len, y_len = xy_len
-        if padding_mode == "br":
-            padded_data[i, :x_len, :y_len, ...] = data[i]
-        elif padding_mode == "tl":
-            padded_data[i, -x_len:, -y_len:, ...] = data[i]
-        elif padding_mode == "random":
-            xmax = longest_x - x_len
-            ymax = longest_y - y_len
-            xi = randint(0, xmax) if xmax >= 1 else 0
-            yi = randint(0, ymax) if ymax >= 1 else 0
-            padded_data[i, xi:xi+x_len, yi:yi+y_len, ...] = data[i]
-        else:
-            raise NotImplementedError("Undefined padding mode: {}".format(padding_mode))
-    return padded_data
-
-
-def pad_image(image, padding_value, new_height=None, new_width=None, pad_width=None, pad_height=None, padding_mode="br", return_position=False):
-    """
-    image: numpy array
-    mode: "br"/"tl"/"random" (bottom-right, top-left, random)
-    """
-    if pad_width is not None and new_width is not None:
-        raise NotImplementedError("pad_with and new_width are not compatible")
-    if pad_height is not None and new_height is not None:
-        raise NotImplementedError("pad_height and new_height are not compatible")
-
-    h, w, c = image.shape
-    pad_width = pad_width if pad_width is not None else max(0, new_width - w) if new_width is not None else 0
-    pad_height = pad_height if pad_height is not None else max(0, new_height - h) if new_height is not None else 0
-
-    if not (pad_width == 0 and pad_height == 0):
-        padded_image = np.ones((h+pad_height, w+pad_width, c)) * padding_value
-        if padding_mode == "br":
-            hi, wi = 0, 0
-        elif padding_mode == "tl":
-            hi, wi = pad_height, pad_width
-        elif padding_mode == "random":
-            hi = randint(0, pad_height) if pad_height >= 1 else 0
-            wi = randint(0, pad_width) if pad_width >= 1 else 0
-        else:
-            raise NotImplementedError("Undefined padding mode: {}".format(padding_mode))
-        padded_image[hi:hi + h, wi:wi + w, ...] = image
-        output = padded_image
-    else:
-        hi, wi = 0, 0
-        output = image
-
-    if return_position:
-        return output, [[hi, hi+h], [wi, wi+w]]
-    return output
-
-
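A sketch of the "br" mode, which keeps the content in the top-left corner and can return its position:

    import numpy as np
    padded, pos = pad_image(np.zeros((32, 100, 3)), padding_value=255, new_height=64,
                            new_width=128, padding_mode="br", return_position=True)
    print(padded.shape, pos)  # (64, 128, 3) [[0, 32], [0, 100]]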
-def pad_image_width_right(img, new_width, padding_value):
-    """
-    Pad img to right side with padding value to reach new_width as width
-    """
-    h, w, c = img.shape
-    pad_width = max((new_width - w), 0)
-    pad_right = np.ones((h, pad_width, c), dtype=img.dtype) * padding_value
-    img = np.concatenate([img, pad_right], axis=1)
-    return img
-
-
-def pad_image_width_left(img, new_width, padding_value):
-    """
-    Pad img to left side with padding value to reach new_width as width
-    """
-    h, w, c = img.shape
-    pad_width = max((new_width - w), 0)
-    pad_left = np.ones((h, pad_width, c), dtype=img.dtype) * padding_value
-    img = np.concatenate([pad_left, img], axis=1)
-    return img
-
-
-def pad_image_width_random(img, new_width, padding_value, max_pad_left_ratio=1):
-    """
-    Randomly pad img to left and right sides with padding value to reach new_width as width
-    """
-    h, w, c = img.shape
-    pad_width = max((new_width - w), 0)
-    max_pad_left = int(max_pad_left_ratio*pad_width)
-    pad_left = randint(0, min(pad_width, max_pad_left)) if pad_width != 0 and max_pad_left > 0 else 0
-    pad_right = pad_width - pad_left
-    pad_left = np.ones((h, pad_left, c), dtype=img.dtype) * padding_value
-    pad_right = np.ones((h, pad_right, c), dtype=img.dtype) * padding_value
-    img = np.concatenate([pad_left, img, pad_right], axis=1)
-    return img
-
-
-def pad_image_height_random(img, new_height, padding_value, max_pad_top_ratio=1):
-    """
-    Randomly pad img top and bottom sides with padding value to reach new_height as height
-    """
-    h, w, c = img.shape
-    pad_height = max((new_height - h), 0)
-    max_pad_top = int(max_pad_top_ratio*pad_height)
-    pad_top = randint(0, min(pad_height, max_pad_top)) if pad_height != 0 and max_pad_top > 0 else 0
-    pad_bottom = pad_height - pad_top
-    pad_top = np.ones((pad_top, w, c), dtype=img.dtype) * padding_value
-    pad_bottom = np.ones((pad_bottom, w, c), dtype=img.dtype) * padding_value
-    img = np.concatenate([pad_top, img, pad_bottom], axis=0)
-    return img
-
-
-def pad_image_height_bottom(img, new_height, padding_value):
-    """
-    Pad img to bottom side with padding value to reach new_height as height
-    """
-    h, w, c = img.shape
-    pad_height = max((new_height - h), 0)
-    pad_bottom = np.ones((pad_height, w, c)) * padding_value
-    img = np.concatenate([img, pad_bottom], axis=0)
-    return img
diff --git a/prediction-requirements.txt b/prediction-requirements.txt
index 1fe32d37816bfb6c2456ebe2bb7b9858ffef49e9..86bfaeb8bf1dcef9ad52f4c64507482f5d942024 100644
--- a/prediction-requirements.txt
+++ b/prediction-requirements.txt
@@ -1,4 +1,8 @@
-numpy==1.22.3
-opencv-python==4.5.5.64
-PyYAML==6.0
-torch==1.11.0
+arkindex-client==1.0.11
+editdistance==0.6.0
+fontTools==4.29.1
+imageio==2.16.0
+networkx==2.6.3
+tensorboard==0.2.1
+torchvision==0.12.0
+tqdm==4.62.3
diff --git a/requirements.txt b/requirements.txt
index 06c4835601286c1154ad24641bf4e102babf8fe1..1fe32d37816bfb6c2456ebe2bb7b9858ffef49e9 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,81 +1,4 @@
-absl-py==1.0.0
-backports.cached-property==1.0.1
-cachetools==5.0.0
-certifi==2021.10.8
-charset-normalizer==2.0.12
-click==8.0.4
-click-default-group==1.2.2
-cloup==0.7.1
-colorama==0.4.4
-colour==0.1.5
-commonmark==0.9.1
-cycler==0.11.0
-decorator==5.1.1
-EasyProcess==1.1
-editdistance==0.6.0
-entrypoint2==1.0
-fonttools==4.29.1
-glcontext==2.3.4
-google-auth==2.6.0
-google-auth-oauthlib==0.4.6
-grpcio==1.44.0
-idna==3.3
-imageio==2.16.0
-importlib-metadata==4.11.1
-isosurfaces==0.1.0
-joblib==1.1.0
-kiwisolver==1.3.2
-manim==0.15.0
-manimlib==0.2.0
-ManimPango==0.4.0.post2
-mapbox-earcut==0.12.11
-Markdown==3.3.6
-masked-norm==0.0.0
-matplotlib==3.5.1
-moderngl==5.6.4
-moderngl-window==2.4.1
-multipledispatch==0.6.0
-networkx==2.6.3
-numpy==1.22.2
-oauthlib==3.2.0
-opencv-python==4.5.5.62
-packaging==21.3
-Pillow==9.0.1
-pkg_resources==0.0.0
-progressbar==2.5
-protobuf==3.19.4
-pyasn1==0.4.8
-pyasn1-modules==0.2.8
-pycairo==1.20.1
-pydub==0.25.1
-pyglet==1.5.21
-Pygments==2.11.2
-pyparsing==3.0.7
-pyrr==0.10.3
-python-dateutil==2.8.2
-pyunpack==0.2.2
-PyWavelets==1.2.0
-requests==2.27.1
-requests-oauthlib==1.3.1
-rich==11.2.0
-rsa==4.8
-scikit-image==0.19.2
-scikit-learn==1.0.2
-scipy==1.8.0
-screeninfo==0.8
-six==1.16.0
-skia-pathops==0.7.2
-srt==3.5.1
-tensorboard==2.8.0
-tensorboard-data-server==0.6.1
-tensorboard-plugin-wit==1.8.1
-threadpoolctl==3.1.0
-tifffile==2022.2.9
-torch==1.8.1
-torchvision==0.9.1
-tqdm==4.62.3
-typing_extensions==4.1.1
-urllib3==1.26.8
-watchdog==2.1.6
-Werkzeug==2.0.3
-zipp==3.7.0
+numpy==1.22.3
+opencv-python==4.5.5.64
+PyYAML==6.0
+torch==1.11.0
diff --git a/setup.py b/setup.py
index bb7648588ce4b0c76de9cc0ed2351a18c9a91663..a0002964d8cb9cbb10e2a5ec8a3b3b62cfcc384b 100755
--- a/setup.py
+++ b/setup.py
@@ -6,8 +6,8 @@ from pathlib import Path
 from setuptools import find_packages, setup
 
 
-def parse_requirements():
-    path = Path(__file__).parent.resolve() / "prediction-requirements.txt"
+def parse_requirements(filename):
+    path = Path(__file__).parent.resolve() / filename
     assert path.exists(), f"Missing requirements: {path}"
     return list(map(str.strip, path.read_text().splitlines()))
 
@@ -21,6 +21,12 @@ setup(
     author="Teklia",
     author_email="contact@teklia.com",
     url="https://gitlab.com/teklia/dan",
-    install_requires=parse_requirements(),
+    install_requires=parse_requirements("requirements.txt"),
     packages=find_packages(),
+    entry_points={
+        "console_scripts": [
+            "teklia-dan=dan.cli:main",
+        ]
+    },
+    extras_require={"predict": parse_requirements("prediction-requirements.txt")},
 )
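With this change, a plain "pip install ." resolves the slimmed-down requirements.txt, "pip install .[predict]" additionally pulls the prediction dependencies, and the new teklia-dan console script dispatches to dan.cli:main.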
diff --git a/visual.png b/visual.png
deleted file mode 100644
index e228e7924cb2760dc61636651b525ef5359408c5..0000000000000000000000000000000000000000
Binary files a/visual.png and /dev/null differ
diff --git a/visual_slanted_lines.png b/visual_slanted_lines.png
deleted file mode 100644
index 478af3ac1eb40316bbb1f0e7400f592b0d31ec70..0000000000000000000000000000000000000000
Binary files a/visual_slanted_lines.png and /dev/null differ