diff --git a/dan/manager/dataset.py b/dan/manager/dataset.py
index 9e9b52a17f700eff39d5f90c30942206a7a4af44..4afb776c9389ed7c9c7e6003a1a7e3068d5187ce 100644
--- a/dan/manager/dataset.py
+++ b/dan/manager/dataset.py
@@ -1,239 +1,90 @@
 # -*- coding: utf-8 -*-
+import copy
 import json
 import os
-import random
 
 import numpy as np
 import torch
-from torch.utils.data import DataLoader, Dataset
-from torch.utils.data.distributed import DistributedSampler
+from torch.utils.data import Dataset
 from torchvision.io import ImageReadMode, read_image
 
 from dan.datasets.utils import natural_sort
-from dan.transforms import get_augmentation_transforms, get_preprocessing_transforms
+from dan.utils import token_to_ind
 
 
-class DatasetManager:
-    def __init__(self, params, device: str):
-        self.params = params
-        self.dataset_class = None
-
-        self.my_collate_function = None
-        # Whether data should be copied on GPU via https://pytorch.org/docs/stable/generated/torch.Tensor.pin_memory.html
-        self.pin_memory = device != "cpu"
-
-        self.train_dataset = None
-        self.valid_datasets = dict()
-        self.test_datasets = dict()
+class OCRDataset(Dataset):
+    """
+    Dataset class to handle dataset loading
+    """
 
-        self.train_loader = None
-        self.valid_loaders = dict()
-        self.test_loaders = dict()
+    def __init__(
+        self,
+        set_name,
+        paths_and_sets,
+        charset,
+        tokens,
+        preprocessing_transforms,
+        augmentation_transforms,
+        load_in_memory=False,
+        mean=None,
+        std=None,
+    ):
+        self.set_name = set_name
+        self.charset = charset
+        self.tokens = tokens
+        self.load_in_memory = load_in_memory
+        self.mean = mean
+        self.std = std
 
-        self.train_sampler = None
-        self.valid_samplers = dict()
-        self.test_samplers = dict()
+        # Pre-processing, augmentation
+        self.preprocessing_transforms = preprocessing_transforms
+        self.augmentation_transforms = augmentation_transforms
 
-        self.generator = torch.Generator()
-        self.generator.manual_seed(0)
+        # Factor to reduce the height and width of the feature vector before feeding the decoder.
+        self.reduce_dims_factor = np.array([32, 8, 1])
 
-        self.batch_size = {
-            "train": self.params["batch_size"],
-            "val": self.params["valid_batch_size"]
-            if "valid_batch_size" in self.params
-            else self.params["batch_size"],
-            "test": self.params["test_batch_size"]
-            if "test_batch_size" in self.params
-            else 1,
-        }
+        # Load samples and preprocess images if load_in_memory is True
+        self.samples = self.load_samples(paths_and_sets)
 
-    def apply_specific_treatment_after_dataset_loading(self, dataset):
-        raise NotImplementedError
+        # Curriculum config
+        self.curriculum_config = None
 
-    def load_datasets(self):
+    def __len__(self):
         """
-        Load training and validation datasets
+        Return the dataset size
         """
-        self.train_dataset = self.dataset_class(
-            self.params,
-            "train",
-            self.params["train"]["name"],
-            self.get_paths_and_sets(self.params["train"]["datasets"]),
-            augmentation_transforms=(
-                get_augmentation_transforms()
-                if self.params["config"]["augmentation"]
-                else None
-            ),
-        )
-
-        (
-            self.params["config"]["mean"],
-            self.params["config"]["std"],
-        ) = self.train_dataset.compute_std_mean()
-
-        self.my_collate_function = self.train_dataset.collate_function(
-            self.params["config"]
-        )
-        self.apply_specific_treatment_after_dataset_loading(self.train_dataset)
-
-        for custom_name in self.params["val"].keys():
-            self.valid_datasets[custom_name] = self.dataset_class(
-                self.params,
-                "val",
-                custom_name,
-                self.get_paths_and_sets(self.params["val"][custom_name]),
-                augmentation_transforms=None,
-            )
-            self.apply_specific_treatment_after_dataset_loading(
-                self.valid_datasets[custom_name]
-            )
+        return len(self.samples)
 
-    def load_ddp_samplers(self):
+    def __getitem__(self, idx):
         """
-        Load training and validation data samplers
+        Return an item from the dataset (image and label)
         """
-        if self.params["use_ddp"]:
-            self.train_sampler = DistributedSampler(
-                self.train_dataset,
-                num_replicas=self.params["num_gpu"],
-                rank=self.params["ddp_rank"],
-                shuffle=True,
-            )
-            for custom_name in self.valid_datasets.keys():
-                self.valid_samplers[custom_name] = DistributedSampler(
-                    self.valid_datasets[custom_name],
-                    num_replicas=self.params["num_gpu"],
-                    rank=self.params["ddp_rank"],
-                    shuffle=False,
-                )
-        else:
-            for custom_name in self.valid_datasets.keys():
-                self.valid_samplers[custom_name] = None
+        # Load preprocessed image
+        sample = copy.deepcopy(self.samples[idx])
+        if not self.load_in_memory:
+            sample["img"] = self.get_sample_img(idx)
 
-    def load_dataloaders(self):
-        """
-        Load training and validation data loaders
-        """
-        self.train_loader = DataLoader(
-            self.train_dataset,
-            batch_size=self.batch_size["train"],
-            shuffle=True if self.train_sampler is None else False,
-            drop_last=False,
-            batch_sampler=self.train_sampler,
-            sampler=self.train_sampler,
-            num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
-            pin_memory=self.pin_memory,
-            collate_fn=self.my_collate_function,
-            worker_init_fn=self.seed_worker,
-            generator=self.generator,
-        )
+        # Convert to numpy
+        sample["img"] = np.array(sample["img"])
 
-        for key in self.valid_datasets.keys():
-            self.valid_loaders[key] = DataLoader(
-                self.valid_datasets[key],
-                batch_size=self.batch_size["val"],
-                sampler=self.valid_samplers[key],
-                batch_sampler=self.valid_samplers[key],
-                shuffle=False,
-                num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
-                pin_memory=self.pin_memory,
-                drop_last=False,
-                collate_fn=self.my_collate_function,
-                worker_init_fn=self.seed_worker,
-                generator=self.generator,
-            )
+        # Apply data augmentation
+        if self.augmentation_transforms:
+            sample["img"] = self.augmentation_transforms(image=sample["img"])["image"]
 
-    @staticmethod
-    def seed_worker(worker_id):
-        worker_seed = torch.initial_seed() % 2**32
-        np.random.seed(worker_seed)
-        random.seed(worker_seed)
+        # Image normalization
+        sample["img"] = (sample["img"] - self.mean) / self.std
 
-    def generate_test_loader(self, custom_name, sets_list):
-        """
-        Load test dataset, data sampler and data loader
-        """
-        if custom_name in self.test_loaders.keys():
-            return
-        paths_and_sets = list()
-        for set_info in sets_list:
-            paths_and_sets.append(
-                {"path": self.params["datasets"][set_info[0]], "set_name": set_info[1]}
-            )
-        self.test_datasets[custom_name] = self.dataset_class(
-            self.params,
-            "test",
-            custom_name,
-            paths_and_sets,
-        )
-        self.apply_specific_treatment_after_dataset_loading(
-            self.test_datasets[custom_name]
-        )
-        if self.params["use_ddp"]:
-            self.test_samplers[custom_name] = DistributedSampler(
-                self.test_datasets[custom_name],
-                num_replicas=self.params["num_gpu"],
-                rank=self.params["ddp_rank"],
-                shuffle=False,
-            )
-        else:
-            self.test_samplers[custom_name] = None
-        self.test_loaders[custom_name] = DataLoader(
-            self.test_datasets[custom_name],
-            batch_size=self.batch_size["test"],
-            sampler=self.test_samplers[custom_name],
-            shuffle=False,
-            num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
-            pin_memory=self.pin_memory,
-            drop_last=False,
-            collate_fn=self.my_collate_function,
-            worker_init_fn=self.seed_worker,
-            generator=self.generator,
+        # Get final height and width
+        sample["img_reduced_shape"], sample["img_position"] = self.compute_final_size(
+            sample["img"]
         )
 
-    def get_paths_and_sets(self, dataset_names_folds):
-        paths_and_sets = list()
-        for dataset_name, fold in dataset_names_folds:
-            path = self.params["datasets"][dataset_name]
-            paths_and_sets.append({"path": path, "set_name": fold})
-        return paths_and_sets
-
-
-class GenericDataset(Dataset):
-    """
-    Main class to handle dataset loading
-    """
-
-    def __init__(self, params, set_name, custom_name, paths_and_sets):
-        self.params = params
-        self.name = custom_name
-        self.set_name = set_name
-        self.mean = (
-            np.array(params["config"]["mean"])
-            if "mean" in params["config"].keys()
-            else None
-        )
-        self.std = (
-            np.array(params["config"]["std"])
-            if "std" in params["config"].keys()
-            else None
-        )
-        self.preprocessing_transforms = get_preprocessing_transforms(
-            params["config"]["preprocessings"]
-        )
-        self.load_in_memory = (
-            self.params["config"]["load_in_memory"]
-            if "load_in_memory" in self.params["config"]
-            else True
+        # Convert label into tokens
+        sample["token_label"], sample["label_len"] = self.convert_sample_label(
+            sample["label"]
         )
 
-        # Load samples and preprocess images if load_in_memory is True
-        self.samples = self.load_samples(paths_and_sets)
-
-        self.curriculum_config = None
-
-    def __len__(self):
-        return len(self.samples)
+        return sample
 
     @staticmethod
     def load_image(path):
@@ -273,6 +124,17 @@ class GenericDataset(Dataset):
                     )
         return samples
 
+    def get_sample_img(self, i):
+        """
+        Get image by index
+        """
+        if self.load_in_memory:
+            return self.samples[i]["img"]
+        else:
+            return self.preprocessing_transforms(
+                self.load_image(self.samples[i]["path"])
+            )
+
     def compute_std_mean(self):
         """
         Compute cumulated variance and mean of whole dataset
@@ -299,13 +161,27 @@ class GenericDataset(Dataset):
                 self.std = np.sqrt(diff / nb_pixels)
         return self.mean, self.std
 
-    def get_sample_img(self, i):
+    def compute_final_size(self, img):
         """
-        Get image by index
+        Compute the final image size and position after feature extraction
         """
-        if self.load_in_memory:
-            return self.samples[i]["img"]
-        else:
-            return self.preprocessing_transforms(
-                self.load_image(self.samples[i]["path"])
-            )
+        image_reduced_shape = np.ceil(img.shape / self.reduce_dims_factor).astype(int)
+
+        if self.set_name == "train":
+            image_reduced_shape = [max(1, t) for t in image_reduced_shape]
+
+        image_position = [
+            [0, img.shape[0]],
+            [0, img.shape[1]],
+        ]
+        return image_reduced_shape, image_position
+
+    def convert_sample_label(self, label):
+        """
+        Tokenize the label and return its length
+        """
+        token_label = token_to_ind(self.charset, label)
+        token_label.append(self.tokens["end"])
+        label_len = len(token_label)
+        token_label.insert(0, self.tokens["start"])
+        return token_label, label_len
diff --git a/dan/manager/ocr.py b/dan/manager/ocr.py
index 6e32c586423f67b23a0f7b6f0af0113061bd77f3..9425632335798b5eb7933d8462bd5da7e66b6158 100644
--- a/dan/manager/ocr.py
+++ b/dan/manager/ocr.py
@@ -1,37 +1,229 @@
 # -*- coding: utf-8 -*-
-import copy
 import os
 import pickle
+import random
 
 import numpy as np
 import torch
+from torch.utils.data import DataLoader
+from torch.utils.data.distributed import DistributedSampler
 
-from dan.manager.dataset import DatasetManager, GenericDataset
-from dan.utils import pad_images, pad_sequences_1D, token_to_ind
+from dan.manager.dataset import OCRDataset
+from dan.transforms import get_augmentation_transforms, get_preprocessing_transforms
+from dan.utils import pad_images, pad_sequences_1D
 
 
-class OCRDatasetManager(DatasetManager):
-    """
-    Specific class to handle OCR/HTR tasks
-    """
-
+class OCRDatasetManager:
     def __init__(self, params, device: str):
-        super(OCRDatasetManager, self).__init__(params, device)
+        self.params = params
+
+        # Whether data should be copied on GPU via https://pytorch.org/docs/stable/generated/torch.Tensor.pin_memory.html
+        self.pin_memory = device != "cpu"
+
+        self.train_dataset = None
+        self.valid_datasets = dict()
+        self.test_datasets = dict()
+
+        self.train_loader = None
+        self.valid_loaders = dict()
+        self.test_loaders = dict()
+
+        self.train_sampler = None
+        self.valid_samplers = dict()
+        self.test_samplers = dict()
 
-        self.dataset_class = OCRDataset
-        self.charset = (
-            params["charset"] if "charset" in params else self.get_merged_charsets()
+        self.mean = (
+            np.array(params["config"]["mean"])
+            if "mean" in params["config"].keys()
+            else None
         )
+        self.std = (
+            np.array(params["config"]["std"])
+            if "std" in params["config"].keys()
+            else None
+        )
+
+        self.generator = torch.Generator()
+        self.generator.manual_seed(0)
 
-        self.tokens = {"pad": len(self.charset) + 2}
-        self.tokens["end"] = len(self.charset)
-        self.tokens["start"] = len(self.charset) + 1
+        self.batch_size = self.get_batch_size_by_set()
+
+        self.load_in_memory = (
+            self.params["config"]["load_in_memory"]
+            if "load_in_memory" in self.params["config"]
+            else True
+        )
+        self.charset = self.get_charset()
+        self.tokens = self.get_tokens()
         self.params["config"]["padding_token"] = self.tokens["pad"]
 
-    def get_merged_charsets(self):
+        self.my_collate_function = OCRCollateFunction(self.params["config"])
+        self.augmentation = (
+            get_augmentation_transforms()
+            if self.params["config"]["augmentation"]
+            else None
+        )
+        self.preprocessing = get_preprocessing_transforms(
+            params["config"]["preprocessings"]
+        )
+
+    def load_datasets(self):
+        """
+        Load training and validation datasets
+        """
+        self.train_dataset = OCRDataset(
+            set_name="train",
+            paths_and_sets=self.get_paths_and_sets(self.params["train"]["datasets"]),
+            charset=self.charset,
+            tokens=self.tokens,
+            preprocessing_transforms=self.preprocessing,
+            augmentation_transforms=self.augmentation,
+            load_in_memory=self.load_in_memory,
+            mean=self.mean,
+            std=self.std,
+        )
+
+        self.mean, self.std = self.train_dataset.compute_std_mean()
+
+        for custom_name in self.params["val"].keys():
+            self.valid_datasets[custom_name] = OCRDataset(
+                set_name="val",
+                paths_and_sets=self.get_paths_and_sets(self.params["val"][custom_name]),
+                charset=self.charset,
+                tokens=self.tokens,
+                preprocessing_transforms=self.preprocessing,
+                augmentation_transforms=None,
+                load_in_memory=self.load_in_memory,
+                mean=self.mean,
+                std=self.std,
+            )
+
+    def load_ddp_samplers(self):
+        """
+        Load training and validation data samplers
+        """
+        if self.params["use_ddp"]:
+            self.train_sampler = DistributedSampler(
+                self.train_dataset,
+                num_replicas=self.params["num_gpu"],
+                rank=self.params["ddp_rank"],
+                shuffle=True,
+            )
+            for custom_name in self.valid_datasets.keys():
+                self.valid_samplers[custom_name] = DistributedSampler(
+                    self.valid_datasets[custom_name],
+                    num_replicas=self.params["num_gpu"],
+                    rank=self.params["ddp_rank"],
+                    shuffle=False,
+                )
+        else:
+            for custom_name in self.valid_datasets.keys():
+                self.valid_samplers[custom_name] = None
+
+    def load_dataloaders(self):
+        """
+        Load training and validation data loaders
+        """
+        self.train_loader = DataLoader(
+            self.train_dataset,
+            batch_size=self.batch_size["train"],
+            shuffle=True if self.train_sampler is None else False,
+            drop_last=False,
+            batch_sampler=self.train_sampler,
+            sampler=self.train_sampler,
+            num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
+            pin_memory=self.pin_memory,
+            collate_fn=self.my_collate_function,
+            worker_init_fn=self.seed_worker,
+            generator=self.generator,
+        )
+
+        for key in self.valid_datasets.keys():
+            self.valid_loaders[key] = DataLoader(
+                self.valid_datasets[key],
+                batch_size=self.batch_size["val"],
+                sampler=self.valid_samplers[key],
+                batch_sampler=self.valid_samplers[key],
+                shuffle=False,
+                num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
+                pin_memory=self.pin_memory,
+                drop_last=False,
+                collate_fn=self.my_collate_function,
+                worker_init_fn=self.seed_worker,
+                generator=self.generator,
+            )
+
+    @staticmethod
+    def seed_worker(worker_id):
+        """
+        Set worker seed
+        """
+        worker_seed = torch.initial_seed() % 2**32
+        np.random.seed(worker_seed)
+        random.seed(worker_seed)
+
+    def generate_test_loader(self, custom_name, sets_list):
+        """
+        Load test dataset, data sampler and data loader
+        """
+        if custom_name in self.test_loaders.keys():
+            return
+        paths_and_sets = list()
+        for set_info in sets_list:
+            paths_and_sets.append(
+                {"path": self.params["datasets"][set_info[0]], "set_name": set_info[1]}
+            )
+        self.test_datasets[custom_name] = OCRDataset(
+            set_name="test",
+            paths_and_sets=paths_and_sets,
+            charset=self.charset,
+            tokens=self.tokens,
+            preprocessing_transforms=self.preprocessing,
+            augmentation_transforms=None,
+            load_in_memory=self.load_in_memory,
+            mean=self.mean,
+            std=self.std,
+        )
+
+        if self.params["use_ddp"]:
+            self.test_samplers[custom_name] = DistributedSampler(
+                self.test_datasets[custom_name],
+                num_replicas=self.params["num_gpu"],
+                rank=self.params["ddp_rank"],
+                shuffle=False,
+            )
+        else:
+            self.test_samplers[custom_name] = None
+
+        self.test_loaders[custom_name] = DataLoader(
+            self.test_datasets[custom_name],
+            batch_size=self.batch_size["test"],
+            sampler=self.test_samplers[custom_name],
+            shuffle=False,
+            num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
+            pin_memory=self.pin_memory,
+            drop_last=False,
+            collate_fn=self.my_collate_function,
+            worker_init_fn=self.seed_worker,
+            generator=self.generator,
+        )
+
+    def get_paths_and_sets(self, dataset_names_folds):
+        """
+        Set the right path for each data set
+        """
+        paths_and_sets = list()
+        for dataset_name, fold in dataset_names_folds:
+            path = self.params["datasets"][dataset_name]
+            paths_and_sets.append({"path": path, "set_name": fold})
+        return paths_and_sets
+
+    def get_charset(self):
         """
         Merge the charset of the different datasets used
         """
+        if "charset" in self.params:
+            return self.params["charset"]
         datasets = self.params["datasets"]
         charset = set()
         for key in datasets.keys():
@@ -41,81 +233,29 @@ class OCRDatasetManager(DatasetManager):
             charset.remove("")
         return sorted(list(charset))
 
-    def apply_specific_treatment_after_dataset_loading(self, dataset):
-        dataset.charset = self.charset
-        dataset.tokens = self.tokens
-        dataset.convert_labels()
-
-
-class OCRDataset(GenericDataset):
-    """
-    Specific class to handle OCR/HTR datasets
-    """
-
-    def __init__(
-        self,
-        params,
-        set_name,
-        custom_name,
-        paths_and_sets,
-        augmentation_transforms=None,
-    ):
-        super(OCRDataset, self).__init__(params, set_name, custom_name, paths_and_sets)
-        self.charset = None
-        self.tokens = None
-        # Factor to reduce the height and width of the feature vector before feeding the decoder.
-        self.reduce_dims_factor = np.array([32, 8, 1])
-        self.collate_function = OCRCollateFunction
-        self.augmentation_transforms = augmentation_transforms
-
-    def __getitem__(self, idx):
-        sample = copy.deepcopy(self.samples[idx])
-
-        if not self.load_in_memory:
-            sample["img"] = self.get_sample_img(idx)
-
-        # Convert to numpy
-        sample["img"] = np.array(sample["img"])
-
-        # Data augmentation
-        if self.augmentation_transforms:
-            sample["img"] = self.augmentation_transforms(image=sample["img"])["image"]
-
-        # Normalization
-        sample["img"] = (sample["img"] - self.mean) / self.std
-
-        sample["img_reduced_shape"] = np.ceil(
-            sample["img"].shape / self.reduce_dims_factor
-        ).astype(int)
-
-        if self.set_name == "train":
-            sample["img_reduced_shape"] = [
-                max(1, t) for t in sample["img_reduced_shape"]
-            ]
-
-        sample["img_position"] = [
-            [0, sample["img"].shape[0]],
-            [0, sample["img"].shape[1]],
-        ]
-
-        return sample
-
-    def convert_labels(self):
+    def get_tokens(self):
         """
-        Label str to token at character level
+        Get special tokens
         """
-        for i in range(len(self.samples)):
-            self.samples[i] = self.convert_sample_labels(self.samples[i])
-
-    def convert_sample_labels(self, sample):
-        label = sample["label"]
+        return {
+            "end": len(self.charset),
+            "start": len(self.charset) + 1,
+            "pad": len(self.charset) + 2,
+        }
 
-        sample["label"] = label
-        sample["token_label"] = token_to_ind(self.charset, label)
-        sample["token_label"].append(self.tokens["end"])
-        sample["label_len"] = len(sample["token_label"])
-        sample["token_label"].insert(0, self.tokens["start"])
-        return sample
+    def get_batch_size_by_set(self):
+        """
+        Return batch size for each set
+        """
+        return {
+            "train": self.params["batch_size"],
+            "val": self.params["valid_batch_size"]
+            if "valid_batch_size" in self.params
+            else self.params["batch_size"],
+            "test": self.params["test_batch_size"]
+            if "test_batch_size" in self.params
+            else 1,
+        }
 
 
 class OCRCollateFunction: