diff --git a/dan/manager/dataset.py b/dan/manager/dataset.py
index 0947d6c439eabbc65ded6e0cc7c952b34535bf1e..5f6ad28b9d5ca2556820d9cd1b59a034b6347702 100644
--- a/dan/manager/dataset.py
+++ b/dan/manager/dataset.py
@@ -1,232 +1,86 @@
 # -*- coding: utf-8 -*-
 import json
 import os
-import random
 
 import numpy as np
 import torch
-from torch.utils.data import DataLoader, Dataset
-from torch.utils.data.distributed import DistributedSampler
+from torch.utils.data import Dataset
 from torchvision.io import ImageReadMode, read_image
 
 from dan.datasets.utils import natural_sort
-from dan.transforms import (
-    get_augmentation_transforms,
-    get_normalization_transforms,
-    get_preprocessing_transforms,
-)
+from dan.utils import token_to_ind
 
 
-class DatasetManager:
-    def __init__(self, params, device: str):
-        self.params = params
-        self.dataset_class = None
-
-        self.my_collate_function = None
-        # Whether data should be copied on GPU via https://pytorch.org/docs/stable/generated/torch.Tensor.pin_memory.html
-        self.pin_memory = device != "cpu"
-
-        self.train_dataset = None
-        self.valid_datasets = dict()
-        self.test_datasets = dict()
+class OCRDataset(Dataset):
+    """
+    Dataset class to handle OCR/HTR dataset loading
+    """
 
-        self.train_loader = None
-        self.valid_loaders = dict()
-        self.test_loaders = dict()
+    def __init__(
+        self,
+        set_name,
+        paths_and_sets,
+        charset,
+        tokens,
+        preprocessing_transforms,
+        normalization_transforms,
+        augmentation_transforms,
+        load_in_memory=False,
+    ):
+        self.set_name = set_name
+        self.charset = charset
+        self.tokens = tokens
+        self.load_in_memory = load_in_memory
 
-        self.train_sampler = None
-        self.valid_samplers = dict()
-        self.test_samplers = dict()
+        # Pre-processing, augmentation, normalization
+        self.preprocessing_transforms = preprocessing_transforms
+        self.normalization_transforms = normalization_transforms
+        self.augmentation_transforms = augmentation_transforms
 
-        self.generator = torch.Generator()
-        self.generator.manual_seed(0)
+        # Factor to reduce the height and width of the feature vector before feeding the decoder.
+        self.reduce_dims_factor = np.array([32, 8, 1])
 
-        self.batch_size = {
-            "train": self.params["batch_size"],
-            "val": self.params["valid_batch_size"]
-            if "valid_batch_size" in self.params
-            else self.params["batch_size"],
-            "test": self.params["test_batch_size"]
-            if "test_batch_size" in self.params
-            else 1,
-        }
+        # Load samples and preprocess images if load_in_memory is True
+        self.samples = self.load_samples(paths_and_sets)
 
-    def apply_specific_treatment_after_dataset_loading(self, dataset):
-        raise NotImplementedError
+        # Curriculum config
+        self.curriculum_config = None
 
-    def load_datasets(self):
+    def __len__(self):
         """
-        Load training and validation datasets
+        Return the dataset size
         """
-        self.train_dataset = self.dataset_class(
-            self.params,
-            "train",
-            self.params["train"]["name"],
-            self.get_paths_and_sets(self.params["train"]["datasets"]),
-            normalization_transforms=get_normalization_transforms(),
-            augmentation_transforms=(
-                get_augmentation_transforms()
-                if self.params["config"]["augmentation"]
-                else None
-            ),
-        )
-
-        self.my_collate_function = self.train_dataset.collate_function(
-            self.params["config"]
-        )
-        self.apply_specific_treatment_after_dataset_loading(self.train_dataset)
-
-        for custom_name in self.params["val"].keys():
-            self.valid_datasets[custom_name] = self.dataset_class(
-                self.params,
-                "val",
-                custom_name,
-                self.get_paths_and_sets(self.params["val"][custom_name]),
-                normalization_transforms=get_normalization_transforms(),
-                augmentation_transforms=None,
-            )
-            self.apply_specific_treatment_after_dataset_loading(
-                self.valid_datasets[custom_name]
-            )
+        return len(self.samples)
 
-    def load_ddp_samplers(self):
+    def __getitem__(self, idx):
         """
-        Load training and validation data samplers
+        Return an item from the dataset (image and label)
         """
-        if self.params["use_ddp"]:
-            self.train_sampler = DistributedSampler(
-                self.train_dataset,
-                num_replicas=self.params["num_gpu"],
-                rank=self.params["ddp_rank"],
-                shuffle=True,
-            )
-            for custom_name in self.valid_datasets.keys():
-                self.valid_samplers[custom_name] = DistributedSampler(
-                    self.valid_datasets[custom_name],
-                    num_replicas=self.params["num_gpu"],
-                    rank=self.params["ddp_rank"],
-                    shuffle=False,
-                )
-        else:
-            for custom_name in self.valid_datasets.keys():
-                self.valid_samplers[custom_name] = None
 
-    def load_dataloaders(self):
-        """
-        Load training and validation data loaders
-        """
-        self.train_loader = DataLoader(
-            self.train_dataset,
-            batch_size=self.batch_size["train"],
-            shuffle=True if self.train_sampler is None else False,
-            drop_last=False,
-            batch_sampler=self.train_sampler,
-            sampler=self.train_sampler,
-            num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
-            pin_memory=self.pin_memory,
-            collate_fn=self.my_collate_function,
-            worker_init_fn=self.seed_worker,
-            generator=self.generator,
-        )
+        # Copy the sample and load the preprocessed image if it is not kept in memory
+        sample = dict(**self.samples[idx])
+        if not self.load_in_memory:
+            sample["img"] = self.get_sample_img(idx)
 
-        for key in self.valid_datasets.keys():
-            self.valid_loaders[key] = DataLoader(
-                self.valid_datasets[key],
-                batch_size=self.batch_size["val"],
-                sampler=self.valid_samplers[key],
-                batch_sampler=self.valid_samplers[key],
-                shuffle=False,
-                num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
-                pin_memory=self.pin_memory,
-                drop_last=False,
-                collate_fn=self.my_collate_function,
-                worker_init_fn=self.seed_worker,
-                generator=self.generator,
-            )
+        # Apply data augmentation
+        if self.augmentation_transforms:
+            sample["img"] = self.augmentation_transforms(image=np.array(sample["img"]))[
+                "image"
+            ]
 
-    @staticmethod
-    def seed_worker(worker_id):
-        worker_seed = torch.initial_seed() % 2**32
-        np.random.seed(worker_seed)
-        random.seed(worker_seed)
+        # Image normalization
+        sample["img"] = self.normalization_transforms(sample["img"])
 
-    def generate_test_loader(self, custom_name, sets_list):
-        """
-        Load test dataset, data sampler and data loader
-        """
-        if custom_name in self.test_loaders.keys():
-            return
-        paths_and_sets = list()
-        for set_info in sets_list:
-            paths_and_sets.append(
-                {"path": self.params["datasets"][set_info[0]], "set_name": set_info[1]}
-            )
-        self.test_datasets[custom_name] = self.dataset_class(
-            self.params,
-            "test",
-            custom_name,
-            paths_and_sets,
-            normalization_transforms=get_normalization_transforms(),
-        )
-        self.apply_specific_treatment_after_dataset_loading(
-            self.test_datasets[custom_name]
+        # Compute the reduced shape of the image after the encoder and its position
+        sample["img_reduced_shape"], sample["img_position"] = self.compute_final_size(
+            sample["img"]
         )
-        if self.params["use_ddp"]:
-            self.test_samplers[custom_name] = DistributedSampler(
-                self.test_datasets[custom_name],
-                num_replicas=self.params["num_gpu"],
-                rank=self.params["ddp_rank"],
-                shuffle=False,
-            )
-        else:
-            self.test_samplers[custom_name] = None
-        self.test_loaders[custom_name] = DataLoader(
-            self.test_datasets[custom_name],
-            batch_size=self.batch_size["test"],
-            sampler=self.test_samplers[custom_name],
-            shuffle=False,
-            num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
-            pin_memory=self.pin_memory,
-            drop_last=False,
-            collate_fn=self.my_collate_function,
-            worker_init_fn=self.seed_worker,
-            generator=self.generator,
-        )
-
-    def get_paths_and_sets(self, dataset_names_folds):
-        paths_and_sets = list()
-        for dataset_name, fold in dataset_names_folds:
-            path = self.params["datasets"][dataset_name]
-            paths_and_sets.append({"path": path, "set_name": fold})
-        return paths_and_sets
-
-
-class GenericDataset(Dataset):
-    """
-    Main class to handle dataset loading
-    """
-
-    def __init__(self, params, set_name, custom_name, paths_and_sets):
-        self.params = params
-        self.name = custom_name
-        self.set_name = set_name
 
-        self.preprocessing_transforms = get_preprocessing_transforms(
-            params["config"]["preprocessings"]
-        )
-        self.load_in_memory = (
-            self.params["config"]["load_in_memory"]
-            if "load_in_memory" in self.params["config"]
-            else True
+        # Convert label into tokens
+        sample["token_label"], sample["label_len"] = self.convert_sample_label(
+            sample["label"]
         )
-
-        # Load samples and preprocess images if load_in_memory is True
-        self.samples = self.load_samples(paths_and_sets)
-
-        self.curriculum_config = None
-
-    def __len__(self):
-        return len(self.samples)
+        return sample
 
     @staticmethod
     def load_image(path):
@@ -276,3 +130,31 @@ class GenericDataset(Dataset):
             return self.preprocessing_transforms(
                 self.load_image(self.samples[i]["path"])
             )
+
+    def compute_final_size(self, img):
+        """
+        Compute the final image size and position after feature extraction
+        """
+        final_c, final_h, final_w = img.shape
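+        # Divide (height, width, channels) by self.reduce_dims_factor and round up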
+        image_reduced_shape = np.ceil(
+            [final_h, final_w, final_c] / self.reduce_dims_factor
+        ).astype(int)
+
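+        # During training, keep at least one feature position per dimension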
+        if self.set_name == "train":
+            image_reduced_shape = [max(1, t) for t in image_reduced_shape]
+
+        image_position = [
+            [0, final_h],
+            [0, final_w],
+        ]
+        return image_reduced_shape, image_position
+
+    def convert_sample_label(self, label):
+        """
+        Tokenize the label, add the start and end tokens, and return it with its length
+        """
+        token_label = token_to_ind(self.charset, label)
+        token_label.append(self.tokens["end"])
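+        # label_len counts the characters plus the end token; the start token inserted below is not counted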
+        label_len = len(token_label)
+        token_label.insert(0, self.tokens["start"])
+        return token_label, label_len
diff --git a/dan/manager/ocr.py b/dan/manager/ocr.py
index 1790f4a081cbaa265cf19483c7792f1f7af8e96b..383ac455aea8cd6ed104f81be85eb17bafec8952 100644
--- a/dan/manager/ocr.py
+++ b/dan/manager/ocr.py
@@ -1,35 +1,217 @@
 # -*- coding: utf-8 -*-
 import os
 import pickle
+import random
 
 import numpy as np
+import torch
+from torch.utils.data import DataLoader
+from torch.utils.data.distributed import DistributedSampler
 
-from dan.manager.dataset import DatasetManager, GenericDataset
-from dan.utils import pad_images, pad_sequences_1D, token_to_ind
+from dan.manager.dataset import OCRDataset
+from dan.transforms import (
+    get_augmentation_transforms,
+    get_normalization_transforms,
+    get_preprocessing_transforms,
+)
+from dan.utils import pad_images, pad_sequences_1D
 
 
-class OCRDatasetManager(DatasetManager):
-    """
-    Specific class to handle OCR/HTR tasks
-    """
-
+class OCRDatasetManager:
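+    """
+    Manager class to load the datasets, samplers and data loaders used for OCR/HTR training, validation and testing
+    """
+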
     def __init__(self, params, device: str):
-        super(OCRDatasetManager, self).__init__(params, device)
+        self.params = params
 
-        self.dataset_class = OCRDataset
-        self.charset = (
-            params["charset"] if "charset" in params else self.get_merged_charsets()
-        )
+        # Whether data should be copied on GPU via https://pytorch.org/docs/stable/generated/torch.Tensor.pin_memory.html
+        self.pin_memory = device != "cpu"
+
+        self.train_dataset = None
+        self.valid_datasets = dict()
+        self.test_datasets = dict()
+
+        self.train_loader = None
+        self.valid_loaders = dict()
+        self.test_loaders = dict()
+
+        self.train_sampler = None
+        self.valid_samplers = dict()
+        self.test_samplers = dict()
 
-        self.tokens = {"pad": len(self.charset) + 2}
-        self.tokens["end"] = len(self.charset)
-        self.tokens["start"] = len(self.charset) + 1
+        self.generator = torch.Generator()
+        self.generator.manual_seed(0)
+
+        self.batch_size = self.get_batch_size_by_set()
+
+        self.load_in_memory = (
+            self.params["config"]["load_in_memory"]
+            if "load_in_memory" in self.params["config"]
+            else True
+        )
+        self.charset = self.get_charset()
+        self.tokens = self.get_tokens()
         self.params["config"]["padding_token"] = self.tokens["pad"]
 
-    def get_merged_charsets(self):
+        self.my_collate_function = OCRCollateFunction(self.params["config"])
+        self.normalization = get_normalization_transforms()
+        self.augmentation = (
+            get_augmentation_transforms()
+            if self.params["config"]["augmentation"]
+            else None
+        )
+        self.preprocessing = get_preprocessing_transforms(
+            params["config"]["preprocessings"]
+        )
+
+    def load_datasets(self):
+        """
+        Load training and validation datasets
+        """
+        self.train_dataset = OCRDataset(
+            set_name="train",
+            paths_and_sets=self.get_paths_and_sets(self.params["train"]["datasets"]),
+            charset=self.charset,
+            tokens=self.tokens,
+            preprocessing_transforms=self.preprocessing,
+            normalization_transforms=self.normalization,
+            augmentation_transforms=self.augmentation,
+            load_in_memory=self.load_in_memory,
+        )
+
+        for custom_name in self.params["val"].keys():
+            self.valid_datasets[custom_name] = OCRDataset(
+                set_name="val",
+                paths_and_sets=self.get_paths_and_sets(self.params["val"][custom_name]),
+                charset=self.charset,
+                tokens=self.tokens,
+                preprocessing_transforms=self.preprocessing,
+                normalization_transforms=self.normalization,
+                augmentation_transforms=None,
+                load_in_memory=self.load_in_memory,
+            )
+
+    def load_ddp_samplers(self):
+        """
+        Load training and validation data samplers
+        """
+        if self.params["use_ddp"]:
+            self.train_sampler = DistributedSampler(
+                self.train_dataset,
+                num_replicas=self.params["num_gpu"],
+                rank=self.params["ddp_rank"],
+                shuffle=True,
+            )
+            for custom_name in self.valid_datasets.keys():
+                self.valid_samplers[custom_name] = DistributedSampler(
+                    self.valid_datasets[custom_name],
+                    num_replicas=self.params["num_gpu"],
+                    rank=self.params["ddp_rank"],
+                    shuffle=False,
+                )
+        else:
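+            # Without DDP, no sampler is needed: the training DataLoader shuffles the data itself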
+            for custom_name in self.valid_datasets.keys():
+                self.valid_samplers[custom_name] = None
+
+    def load_dataloaders(self):
+        """
+        Load training and validation data loaders
+        """
+        self.train_loader = DataLoader(
+            self.train_dataset,
+            batch_size=self.batch_size["train"],
+            shuffle=self.train_sampler is None,
+            drop_last=False,
+            batch_sampler=self.train_sampler,
+            sampler=self.train_sampler,
+            num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
+            pin_memory=self.pin_memory,
+            collate_fn=self.my_collate_function,
+            worker_init_fn=self.seed_worker,
+            generator=self.generator,
+        )
+
+        for key in self.valid_datasets.keys():
+            self.valid_loaders[key] = DataLoader(
+                self.valid_datasets[key],
+                batch_size=self.batch_size["val"],
+                sampler=self.valid_samplers[key],
+                batch_sampler=self.valid_samplers[key],
+                shuffle=False,
+                num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
+                pin_memory=self.pin_memory,
+                drop_last=False,
+                collate_fn=self.my_collate_function,
+                worker_init_fn=self.seed_worker,
+                generator=self.generator,
+            )
+
+    @staticmethod
+    def seed_worker(worker_id):
+        """
+        Seed the numpy and random generators of each DataLoader worker
+        """
+        worker_seed = torch.initial_seed() % 2**32
+        np.random.seed(worker_seed)
+        random.seed(worker_seed)
+
+    def generate_test_loader(self, custom_name, sets_list):
+        """
+        Load test dataset, data sampler and data loader
+        """
+        if custom_name in self.test_loaders.keys():
+            return
+        paths_and_sets = list()
+        for set_info in sets_list:
+            paths_and_sets.append(
+                {"path": self.params["datasets"][set_info[0]], "set_name": set_info[1]}
+            )
+        self.test_datasets[custom_name] = OCRDataset(
+            set_name="test",
+            paths_and_sets=paths_and_sets,
+            charset=self.charset,
+            tokens=self.tokens,
+            preprocessing_transforms=self.preprocessing,
+            normalization_transforms=self.normalization,
+            augmentation_transforms=None,
+            load_in_memory=self.load_in_memory,
+        )
+
+        if self.params["use_ddp"]:
+            self.test_samplers[custom_name] = DistributedSampler(
+                self.test_datasets[custom_name],
+                num_replicas=self.params["num_gpu"],
+                rank=self.params["ddp_rank"],
+                shuffle=False,
+            )
+        else:
+            self.test_samplers[custom_name] = None
+        self.test_loaders[custom_name] = DataLoader(
+            self.test_datasets[custom_name],
+            batch_size=self.batch_size["test"],
+            sampler=self.test_samplers[custom_name],
+            shuffle=False,
+            num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
+            pin_memory=self.pin_memory,
+            drop_last=False,
+            collate_fn=self.my_collate_function,
+            worker_init_fn=self.seed_worker,
+            generator=self.generator,
+        )
+
+    def get_paths_and_sets(self, dataset_names_folds):
+        """
+        Return the dataset path and set name for each (dataset name, fold) pair
+        """
+        paths_and_sets = list()
+        for dataset_name, fold in dataset_names_folds:
+            path = self.params["datasets"][dataset_name]
+            paths_and_sets.append({"path": path, "set_name": fold})
+        return paths_and_sets
+
+    def get_charset(self):
         """
         Merge the charset of the different datasets used
         """
+        if "charset" in self.params:
+            return self.params["charset"]
         datasets = self.params["datasets"]
         charset = set()
         for key in datasets.keys():
@@ -39,83 +221,29 @@ class OCRDatasetManager(DatasetManager):
             charset.remove("")
         return sorted(list(charset))
 
-    def apply_specific_treatment_after_dataset_loading(self, dataset):
-        dataset.charset = self.charset
-        dataset.tokens = self.tokens
-        dataset.convert_labels()
-
-
-class OCRDataset(GenericDataset):
-    """
-    Specific class to handle OCR/HTR datasets
-    """
+    def get_tokens(self):
+        """
+        Get special tokens
+        """
+        return {
+            "end": len(self.charset),
+            "start": len(self.charset) + 1,
+            "pad": len(self.charset) + 2,
+        }
 
-    def __init__(
-        self,
-        params,
-        set_name,
-        custom_name,
-        paths_and_sets,
-        normalization_transforms,
-        augmentation_transforms=None,
-    ):
-        super(OCRDataset, self).__init__(params, set_name, custom_name, paths_and_sets)
-        self.charset = None
-        self.tokens = None
-        # Factor to reduce the height and width of the feature vector before feeding the decoder.
-        self.reduce_dims_factor = np.array([32, 8, 1])
-        self.collate_function = OCRCollateFunction
-        self.normalization_transforms = normalization_transforms
-        self.augmentation_transforms = augmentation_transforms
-
-    def __getitem__(self, idx):
-        sample = dict(**self.samples[idx])
-
-        if not self.load_in_memory:
-            sample["img"] = self.get_sample_img(idx)
-
-        # Data augmentation
-        if self.augmentation_transforms:
-            sample["img"] = self.augmentation_transforms(image=np.array(sample["img"]))[
-                "image"
-            ]
-
-        # Normalization
-        sample["img"] = self.normalization_transforms(sample["img"])
-
-        # Get final height and width
-        final_c, final_h, final_w = sample["img"].shape
-        sample["img_reduced_shape"] = np.ceil(
-            [final_h, final_w, final_c] / self.reduce_dims_factor
-        ).astype(int)
-
-        if self.set_name == "train":
-            sample["img_reduced_shape"] = [
-                max(1, t) for t in sample["img_reduced_shape"]
-            ]
-
-        sample["img_position"] = [
-            [0, final_h],
-            [0, final_w],
-        ]
-        return sample
-
-    def convert_labels(self):
-        """
-        Label str to token at character level
-        """
-        for i in range(len(self.samples)):
-            self.samples[i] = self.convert_sample_labels(self.samples[i])
-
-    def convert_sample_labels(self, sample):
-        label = sample["label"]
-
-        sample["label"] = label
-        sample["token_label"] = token_to_ind(self.charset, label)
-        sample["token_label"].append(self.tokens["end"])
-        sample["label_len"] = len(sample["token_label"])
-        sample["token_label"].insert(0, self.tokens["start"])
-        return sample
+    def get_batch_size_by_set(self):
+        """
+        Return batch size for each set
+        """
+        return {
+            "train": self.params["batch_size"],
+            "val": self.params["valid_batch_size"]
+            if "valid_batch_size" in self.params
+            else self.params["batch_size"],
+            "test": self.params["test_batch_size"]
+            if "test_batch_size" in self.params
+            else 1,
+        }
 
 
 class OCRCollateFunction: