diff --git a/dan/manager/dataset.py b/dan/manager/dataset.py index 819e90664ddc39ef410489fdecdfbefa7f1ecbfd..c2a33d89f8ed112ca9160d02deee1bf49e2c81c8 100644 --- a/dan/manager/dataset.py +++ b/dan/manager/dataset.py @@ -3,7 +3,6 @@ import json import os import random -import cv2 import numpy as np import torch from PIL import Image @@ -11,7 +10,7 @@ from torch.utils.data import DataLoader, Dataset from torch.utils.data.distributed import DistributedSampler from dan.datasets.utils import natural_sort -from dan.transforms import apply_data_augmentation +from dan.transforms import get_augmentation_transforms, get_preprocessing_transforms class DatasetManager: @@ -60,7 +59,13 @@ class DatasetManager: "train", self.params["train"]["name"], self.get_paths_and_sets(self.params["train"]["datasets"]), + augmentation_transforms=( + get_augmentation_transforms() + if self.params["config"]["augmentation"] + else None + ), ) + ( self.params["config"]["mean"], self.params["config"]["std"], @@ -77,6 +82,7 @@ class DatasetManager: "val", custom_name, self.get_paths_and_sets(self.params["val"][custom_name]), + augmentation_transforms=None, ) self.apply_specific_treatment_after_dataset_loading( self.valid_datasets[custom_name] @@ -209,7 +215,9 @@ class GenericDataset(Dataset): if "std" in params["config"].keys() else None ) - + self.preprocessing_transforms = get_preprocessing_transforms( + params["config"]["preprocessings"] + ) self.load_in_memory = ( self.params["config"]["load_in_memory"] if "load_in_memory" in self.params["config"] @@ -221,7 +229,7 @@ class GenericDataset(Dataset): ) if self.load_in_memory: - self.apply_preprocessing(params["config"]["preprocessings"]) + self.preprocess_all_images() self.curriculum_config = None @@ -231,11 +239,7 @@ class GenericDataset(Dataset): @staticmethod def load_image(path): with Image.open(path) as pil_img: - img = np.array(pil_img) - # grayscale images - if len(img.shape) == 2: - img = np.expand_dims(img, axis=2) - return img + return pil_img.convert("RGB") @staticmethod def load_samples(paths_and_sets, load_in_memory=True): @@ -266,12 +270,12 @@ class GenericDataset(Dataset): samples[-1]["img"] = GenericDataset.load_image(filename) return samples - def apply_preprocessing(self, preprocessings): - for i in range(len(self.samples)): - ( - self.samples[i]["img"], - self.samples[i]["resize_ratio"], - ) = apply_preprocessing(self.samples[i]["img"], preprocessings) + def preprocess_all_images(self) -> None: + """ + Iterate over all samples and apply pre-processing + """ + for i, sample in enumerate(self.samples): + self.samples[i]["img"] = self.preprocessing_transforms(sample["img"]) def compute_std_mean(self): """ @@ -285,15 +289,11 @@ class GenericDataset(Dataset): nb_pixels = 0 for metric in ["mean", "std"]: for ind in range(len(self.samples)): - img = ( + img = np.array( self.get_sample_img(ind) if self.load_in_memory - else apply_preprocessing( - self.get_sample_img(ind), - self.params["config"]["preprocessings"], - )[0] + else self.preprocessing_transforms(self.get_sample_img(ind)), ) - if metric == "mean": sum += np.sum(img, axis=(0, 1)) nb_pixels += np.prod(img.shape[:2]) @@ -307,19 +307,6 @@ class GenericDataset(Dataset): self.std = np.sqrt(diff / nb_pixels) return self.mean, self.std - def apply_data_augmentation(self, img): - """ - Apply data augmentation strategy on the input image - """ - augs = [ - self.params["config"][key] if key in self.params["config"].keys() else None - for key in ["augmentation", "valid_augmentation", "test_augmentation"] - ] - 
for aug, set_name in zip(augs, ["train", "val", "test"]): - if aug and self.set_name == set_name: - return apply_data_augmentation(img, aug) - return img - def get_sample_img(self, i): """ Get image by index @@ -328,59 +315,3 @@ class GenericDataset(Dataset): return self.samples[i]["img"] else: return GenericDataset.load_image(self.samples[i]["path"]) - - -def apply_preprocessing(img, preprocessings): - """ - Apply preprocessings on an image - """ - resize_ratio = [1, 1] - for preprocessing in preprocessings: - if preprocessing["type"] == "to_grayscaled": - temp_img = img - h, w, c = temp_img.shape - if c == 3: - img = np.expand_dims( - 0.2125 * temp_img[:, :, 0] - + 0.7154 * temp_img[:, :, 1] - + 0.0721 * temp_img[:, :, 2], - axis=2, - ).astype(np.uint8) - - if preprocessing["type"] == "to_RGB": - temp_img = img - h, w, c = temp_img.shape - if c == 1: - img = np.concatenate([temp_img, temp_img, temp_img], axis=2) - - if preprocessing["type"] == "resize": - keep_ratio = preprocessing["keep_ratio"] - max_h, max_w = preprocessing["max_height"], preprocessing["max_width"] - temp_img = img - h, w, c = temp_img.shape - - ratio_h = max_h / h if max_h else 1 - ratio_w = max_w / w if max_w else 1 - if keep_ratio: - ratio_h = ratio_w = min(ratio_w, ratio_h) - new_h = min(max_h, int(h * ratio_h)) - new_w = min(max_w, int(w * ratio_w)) - temp_img = cv2.resize(temp_img, (new_w, new_h)) - if len(temp_img.shape) == 2: - temp_img = np.expand_dims(temp_img, axis=2) - - img = temp_img - resize_ratio = [ratio_h, ratio_w] - - if preprocessing["type"] == "fixed_height": - new_h = preprocessing["height"] - temp_img = img - h, w, c = temp_img.shape - ratio = new_h / h - temp_img = cv2.resize(temp_img, (int(w * ratio), new_h)) - if len(temp_img.shape) == 2: - temp_img = np.expand_dims(temp_img, axis=2) - img = temp_img - resize_ratio = [ratio, ratio] - - return img, resize_ratio diff --git a/dan/manager/ocr.py b/dan/manager/ocr.py index 0ee834d04c0e4fe56e41c18828474451a4e08eb7..ce62989f0c0a790ad584668542118a2be502e50c 100644 --- a/dan/manager/ocr.py +++ b/dan/manager/ocr.py @@ -1,12 +1,12 @@ # -*- coding: utf-8 -*- +import copy import os import pickle -import cv2 import numpy as np import torch -from dan.manager.dataset import DatasetManager, GenericDataset, apply_preprocessing +from dan.manager.dataset import DatasetManager, GenericDataset from dan.utils import pad_images, pad_sequences_1D, token_to_ind @@ -52,41 +52,45 @@ class OCRDataset(GenericDataset): Specific class to handle OCR/HTR datasets """ - def __init__(self, params, set_name, custom_name, paths_and_sets): + def __init__( + self, + params, + set_name, + custom_name, + paths_and_sets, + augmentation_transforms=None, + ): super(OCRDataset, self).__init__(params, set_name, custom_name, paths_and_sets) self.charset = None self.tokens = None # Factor to reduce the height and width of the feature vector before feeding the decoder. 
self.reduce_dims_factor = np.array([32, 8, 1]) self.collate_function = OCRCollateFunction + self.augmentation_transforms = augmentation_transforms def __getitem__(self, idx): - sample = dict(**self.samples[idx]) + sample = copy.deepcopy(self.samples[idx]) if not self.load_in_memory: sample["img"] = self.get_sample_img(idx) - sample["img"], sample["resize_ratio"] = apply_preprocessing( - sample["img"], self.params["config"]["preprocessings"] - ) - # Data augmentation - sample["img"] = self.apply_data_augmentation(sample["img"]) + # Convert to numpy + sample["img"] = np.array(sample["img"]) - if "max_size" in self.params["config"] and self.params["config"]["max_size"]: - max_ratio = max( - sample["img"].shape[0] - / self.params["config"]["max_size"]["max_height"], - sample["img"].shape[1] / self.params["config"]["max_size"]["max_width"], - ) - if max_ratio > 1: - new_h, new_w = int(np.ceil(sample["img"].shape[0] / max_ratio)), int( - np.ceil(sample["img"].shape[1] / max_ratio) - ) - sample["img"] = cv2.resize(sample["img"], (new_w, new_h)) + # Get initial height and width + initial_h, initial_w, _ = sample["img"].shape + + # Data augmentation + if self.augmentation_transforms: + sample["img"] = self.augmentation_transforms(image=sample["img"])["image"] # Normalization sample["img"] = (sample["img"] - self.mean) / self.std + # Get final height and width (tensor mode) + final_h, final_w, _ = sample["img"].shape + sample["resize_ratio"] = [final_h / initial_h, final_w / initial_w] + sample["img_reduced_shape"] = np.ceil( sample["img"].shape / self.reduce_dims_factor ).astype(int) diff --git a/dan/ocr/document/train.py b/dan/ocr/document/train.py index a303ad2f85896f6fea8c98370da10eea0f514d10..4312c15836e5b9533ed5c08e15b7105b7614ebc5 100644 --- a/dan/ocr/document/train.py +++ b/dan/ocr/document/train.py @@ -15,7 +15,7 @@ from dan.encoder import FCN_Encoder from dan.manager.training import Manager from dan.mlflow import MLFLOW_AVAILABLE from dan.schedulers import exponential_dropout_scheduler -from dan.transforms import aug_config +from dan.transforms import Preprocessing from dan.utils import MLflowNotInstalled if MLFLOW_AVAILABLE: @@ -107,11 +107,12 @@ def get_config(): "worker_per_gpu": 4, # Num of parallel processes per gpu for data loading "preprocessings": [ { - "type": "to_RGB", - # if grayscaled image, produce RGB one (3 channels with same value) otherwise do nothing - }, + "type": Preprocessing.MaxResize, + "max_width": 2000, + "max_height": 2000, + } ], - "augmentation": aug_config(0.9, 0.1), + "augmentation": True, }, }, "model_params": { diff --git a/dan/transforms.py b/dan/transforms.py index 5e9a151482536eace8ae9f8eee92455ed3c9cd67..1e34bce9aa9d2d463e47d25c679a7da39c49b88f 100644 --- a/dan/transforms.py +++ b/dan/transforms.py @@ -2,399 +2,221 @@ """ Each transform class defined here takes as input a PIL Image and returns the modified PIL Image """ -import math +from enum import Enum +from random import randint -import cv2 +import albumentations as A import numpy as np -from cv2 import dilate, erode, normalize -from numpy import random -from PIL import Image -from torch import rand, randint -from torch.distributions.uniform import Uniform -from torchvision.transforms import ( +from albumentations.augmentations import ( + Affine, + CoarseDropout, ColorJitter, + Downscale, + ElasticTransform, GaussianBlur, - RandomCrop, - RandomPerspective, + GaussNoise, + Perspective, + PiecewiseAffine, + Sharpen, + ToGray, ) -from torchvision.transforms.functional import InterpolationMode +from cv2 
import dilate, erode +from numpy import random +from PIL import Image +from torch.distributions.uniform import Uniform +from torchvision.transforms import Compose +from torchvision.transforms.functional import resize -class DPIAdjusting: +class Preprocessing(str, Enum): + # If the image is bigger than the given size, resize it while keeping the original ratio + MaxResize = "max_resize" + # Resize the height to a fixed value while keeping the original ratio + FixedHeightResize = "fixed_height_resize" + # Resize the width to a fixed value while keeping the original ratio + FixedWidthResize = "fixed_width_resize" + + +class FixedHeightResize: """ - Resolution modification + Resize an image to a fixed height """ - def __init__(self, factor): - self.factor = factor + def __init__(self, height: int) -> None: + self.height = height - def __call__(self, x): - w, h = x.size - return x.resize( - (int(np.ceil(w * self.factor)), int(np.ceil(h * self.factor))), - Image.BILINEAR, - ) + def __call__(self, img: Image) -> Image: + size = (self.height, self._calc_new_width(img)) + return resize(img, size) + def _calc_new_width(self, img: Image) -> int: + aspect_ratio = img.width / img.height + return round(self.height * aspect_ratio) -class Dilation: + +class FixedWidthResize: """ - OCR: stroke width increasing + Resize an image to a fixed width """ - def __init__(self, kernel, iterations): - self.kernel = np.ones(kernel, np.uint8) - self.iterations = iterations + def __init__(self, width: int) -> None: + self.width = width - def __call__(self, x): - return Image.fromarray( - dilate(np.array(x), self.kernel, iterations=self.iterations) - ) + def __call__(self, img: Image) -> Image: + size = (self._calc_new_height(img), self.width) + return resize(img, size) + def _calc_new_height(self, img: Image) -> int: + aspect_ratio = img.height / img.width + return round(self.width * aspect_ratio) -class Erosion: + +class MaxResize: """ - OCR: stroke width decreasing + Resize an image if it is bigger than the maximum size """ - def __init__(self, kernel, iterations): - self.kernel = np.ones(kernel, np.uint8) - self.iterations = iterations + def __init__(self, height: int, width: int) -> None: + self.max_width = width + self.max_height = height - def __call__(self, x): - return Image.fromarray( - erode(np.array(x), self.kernel, iterations=self.iterations) - ) + def __call__(self, img: Image) -> Image: + width, height = img.width, img.height + if width <= self.max_width and height <= self.max_height: + return img + width_ratio = self.max_width / width + height_ratio = self.max_height / height + ratio = min(height_ratio, width_ratio) + new_width = int(width * ratio) + new_height = int(height * ratio) + return resize(img, (new_height, new_width)) -class GaussianNoise: +class DPIAdjusting: """ - Add Gaussian Noise + Resolution modification """ - def __init__(self, std): - self.std = std + def __init__(self, min_factor: float, max_factor: float, p=1.0): + self.min_factor = min_factor + self.max_factor = max_factor + self.p = p + self.always_apply = False - def __call__(self, x): - x_np = np.array(x) - mean, std = np.mean(x_np), np.std(x_np) - std = math.copysign(max(abs(std), 0.000001), std) - min_, max_ = np.min( - x_np, - ), np.max(x_np) - normal_noise = np.random.randn(*x_np.shape) - if ( - len(x_np.shape) == 3 - and x_np.shape[2] == 3 - and np.all(x_np[:, :, 0] == x_np[:, :, 1]) - and np.all(x_np[:, :, 0] == x_np[:, :, 2]) - ): - normal_noise[:, :, 1] = normal_noise[:, :, 2] = normal_noise[:, :, 0] - x_np = ((x_np - 
mean) / std + normal_noise * self.std) * std + mean - x_np = normalize(x_np, x_np, max_, min_, cv2.NORM_MINMAX) - - return Image.fromarray(x_np.astype(np.uint8)) - - -class Sharpen: + def __call__(self, image: np.array, force_apply: bool = False) -> Image: + if not (random.random() <= self.p or self.always_apply or force_apply): + return {"image": image} + factor = float(Uniform(self.min_factor, self.max_factor).sample()) + image = Image.fromarray(image) + augmented_image = image.resize( + (int(np.ceil(image.width * factor)), int(np.ceil(image.height * factor))), + Image.BILINEAR, + ) + return {"image": np.array(augmented_image)} + + +class Dilation: """ - Add Gaussian Noise + OCR: stroke width increasing """ - def __init__(self, alpha, strength): - self.alpha = alpha - self.strength = strength + def __init__(self, kernel, iterations): + self.kernel = kernel + self.iterations = iterations def __call__(self, x): - x_np = np.array(x) - id_matrix = np.array([[0, 0, 0], [0, 1, 0], [0, 0, 0]]) - effect_matrix = np.array([[1, 1, 1], [1, -(8 + self.strength), 1], [1, 1, 1]]) - kernel = (1 - self.alpha) * id_matrix - self.alpha * effect_matrix - kernel = np.expand_dims(kernel, axis=2) - kernel = np.concatenate([kernel, kernel, kernel], axis=2) - sharpened = cv2.filter2D(x_np, -1, kernel=kernel[:, :, 0]) - return Image.fromarray(sharpened.astype(np.uint8)) + return dilate(np.array(x), self.kernel, iterations=self.iterations) -class ZoomRatio: +class Erosion: """ - Crop by ratio - Preserve dimensions if keep_dim = True (= zoom) + OCR: stroke width decreasing """ - def __init__(self, ratio_h, ratio_w, keep_dim=True): - self.ratio_w = ratio_w - self.ratio_h = ratio_h - self.keep_dim = keep_dim - - def __call__(self, x): - w, h = x.size - x = RandomCrop((int(h * self.ratio_h), int(w * self.ratio_w)))(x) - if self.keep_dim: - x = x.resize((w, h), Image.BILINEAR) - return x - - -class ElasticDistortion: - def __init__(self, kernel_size=(7, 7), sigma=5, alpha=1): - self.kernel_size = kernel_size - self.sigma = sigma - self.alpha = alpha + def __init__(self, kernel, iterations): + self.kernel = kernel + self.iterations = iterations def __call__(self, x): - x_np = np.array(x) - - h, w = x_np.shape[:2] - - dx = np.random.uniform(-1, 1, (h, w)) - dy = np.random.uniform(-1, 1, (h, w)) - - x_gauss = cv2.GaussianBlur(dx, self.kernel_size, self.sigma) - y_gauss = cv2.GaussianBlur(dy, self.kernel_size, self.sigma) - - n = np.sqrt(x_gauss**2 + y_gauss**2) + return erode(np.array(x), self.kernel, iterations=self.iterations) - nd_x = self.alpha * x_gauss / n - nd_y = self.alpha * y_gauss / n - ind_y, ind_x = np.indices((h, w), dtype=np.float32) - - map_x = nd_x + ind_x - map_x = map_x.reshape(h, w).astype(np.float32) - map_y = nd_y + ind_y - map_y = map_y.reshape(h, w).astype(np.float32) +class ErosionDilation: + """ + Random erosion or dilation + """ - dst = cv2.remap(x_np, map_x, map_y, cv2.INTER_LINEAR) - return Image.fromarray(dst.astype(np.uint8)) + def __init__(self, min_kernel, max_kernel, iterations, p=1.0): + self.min_kernel = min_kernel + self.max_kernel = max_kernel + self.iterations = iterations + self.p = p + self.always_apply = False + + def __call__(self, image, force_apply=False): + if not (random.random() <= self.p or self.always_apply or force_apply): + return {"image": image} + kernel_h = randint(self.min_kernel, self.max_kernel) + kernel_w = randint(self.min_kernel, self.max_kernel) + kernel = np.ones((kernel_h, kernel_w), np.uint8) + augmented_image = ( + Erosion(kernel, 
iterations=self.iterations)(image) + if random.random() < 0.5 + else Dilation(kernel=kernel, iterations=self.iterations)(image) + ) + return {"image": augmented_image} -def get_list_augmenters(img, aug_configs, fill_value): +def get_preprocessing_transforms(preprocessings: list) -> Compose: """ - Randomly select a list of data augmentation techniques to used based on aug_configs + Returns a list of transformation to be applied to the image. """ - augmenters = list() - for aug_config in aug_configs: - if rand((1,)) > aug_config["proba"]: - continue - if aug_config["type"] == "dpi": - valid_factor = False - while not valid_factor: - factor = float( - Uniform(aug_config["min_factor"], aug_config["max_factor"]).sample() - ) - valid_factor = not ( - ( - "max_width" in aug_config - and factor * img.size[0] > aug_config["max_width"] - ) - or ( - "max_height" in aug_config - and factor * img.size[1] > aug_config["max_height"] - ) - or ( - "min_width" in aug_config - and factor * img.size[0] < aug_config["min_width"] - ) - or ( - "min_height" in aug_config - and factor * img.size[1] < aug_config["min_height"] + # Convert to Tensor for torchvision transforms + transforms = [] + for preprocessing in preprocessings: + match preprocessing["type"]: + case Preprocessing.MaxResize: + transforms.append( + MaxResize( + height=preprocessing["max_height"], + width=preprocessing["max_width"], ) ) - augmenters.append(DPIAdjusting(factor)) - - elif aug_config["type"] == "zoom_ratio": - ratio_h = float( - Uniform(aug_config["min_ratio_h"], aug_config["max_ratio_h"]).sample() - ) - ratio_w = float( - Uniform(aug_config["min_ratio_w"], aug_config["max_ratio_w"]).sample() - ) - augmenters.append( - ZoomRatio( - ratio_h=ratio_h, ratio_w=ratio_w, keep_dim=aug_config["keep_dim"] - ) - ) - - elif aug_config["type"] == "perspective": - scale = float( - Uniform(aug_config["min_factor"], aug_config["max_factor"]).sample() - ) - augmenters.append( - RandomPerspective( - distortion_scale=scale, - p=1, - interpolation=InterpolationMode.BILINEAR, - fill=fill_value, - ) - ) - - elif aug_config["type"] == "elastic_distortion": - kernel_size = ( - randint( - aug_config["min_kernel_size"], aug_config["max_kernel_size"], (1,) - ).item() - ) // 2 * 2 + 1 - sigma = float( - Uniform(aug_config["min_sigma"], aug_config["max_sigma"]).sample() - ) - alpha = float( - Uniform(aug_config["min_alpha"], aug_config["max_alpha"]).sample() - ) - augmenters.append( - ElasticDistortion( - kernel_size=(kernel_size, kernel_size), sigma=sigma, alpha=alpha - ) - ) - - elif aug_config["type"] == "dilation_erosion": - kernel_h = randint( - aug_config["min_kernel"], aug_config["max_kernel"] + 1, (1,) - ) - kernel_w = randint( - aug_config["min_kernel"], aug_config["max_kernel"] + 1, (1,) - ) - if randint(0, 2, (1,)) == 0: - augmenters.append( - Erosion((kernel_w, kernel_h), aug_config["iterations"]) - ) - else: - augmenters.append( - Dilation((kernel_w, kernel_h), aug_config["iterations"]) + case Preprocessing.FixedHeightResize: + transforms.append( + FixedHeightResize(height=preprocessing["fixed_height"]) ) - - elif aug_config["type"] == "color_jittering": - augmenters.append( - ColorJitter( - contrast=aug_config["factor_contrast"], - brightness=aug_config["factor_brightness"], - saturation=aug_config["factor_saturation"], - hue=aug_config["factor_hue"], - ) - ) - - elif aug_config["type"] == "gaussian_blur": - max_kernel_h = min(aug_config["max_kernel"], img.size[1]) - max_kernel_w = min(aug_config["max_kernel"], img.size[0]) - kernel_h = ( - 
randint(aug_config["min_kernel"], max_kernel_h + 1, (1,)).item() - ) // 2 * 2 + 1 - kernel_w = ( - randint(aug_config["min_kernel"], max_kernel_w + 1, (1,)).item() - ) // 2 * 2 + 1 - sigma = float( - Uniform(aug_config["min_sigma"], aug_config["max_sigma"]).sample() - ) - augmenters.append( - GaussianBlur(kernel_size=(kernel_w, kernel_h), sigma=sigma) - ) - - elif aug_config["type"] == "gaussian_noise": - augmenters.append(GaussianNoise(std=aug_config["std"])) - - elif aug_config["type"] == "sharpen": - alpha = float( - Uniform(aug_config["min_alpha"], aug_config["max_alpha"]).sample() - ) - strength = float( - Uniform(aug_config["min_strength"], aug_config["max_strength"]).sample() - ) - augmenters.append(Sharpen(alpha=alpha, strength=strength)) - - else: - print("Error - unknown augmentor: {}".format(aug_config["type"])) - exit(-1) - - return augmenters - - -def apply_data_augmentation(img, da_config): - """ - Apply data augmentation strategy on input image - """ - if da_config["proba"] != 1 and rand((1,)) > da_config["proba"]: - return img - - # Convert to PIL Image - img = img[:, :, 0] if img.shape[2] == 1 else img - img = Image.fromarray(img) - - fill_value = da_config["fill_value"] if "fill_value" in da_config else 255 - augmenters = get_list_augmenters( - img, da_config["augmentations"], fill_value=fill_value - ) - if da_config["order"] == "random": - random.shuffle(augmenters) - - for augmenter in augmenters: - img = augmenter(img) - - # convert to numpy array - img = np.array(img) - img = np.expand_dims(img, axis=2) if len(img.shape) == 2 else img - return img - - -def aug_config(proba_use_da, p): - return { - "order": "random", - "proba": proba_use_da, - "augmentations": [ - { - "type": "dpi", - "proba": p, - "min_factor": 0.75, - "max_factor": 1, - }, - { - "type": "perspective", - "proba": p, - "min_factor": 0, - "max_factor": 0.4, - }, - { - "type": "elastic_distortion", - "proba": p, - "min_alpha": 0.5, - "max_alpha": 1, - "min_sigma": 1, - "max_sigma": 10, - "min_kernel_size": 3, - "max_kernel_size": 9, - }, - { - "type": "dilation_erosion", - "proba": p, - "min_kernel": 1, - "max_kernel": 3, - "iterations": 1, - }, - { - "type": "color_jittering", - "proba": p, - "factor_hue": 0.2, - "factor_brightness": 0.4, - "factor_contrast": 0.4, - "factor_saturation": 0.4, - }, - { - "type": "gaussian_blur", - "proba": p, - "min_kernel": 3, - "max_kernel": 5, - "min_sigma": 3, - "max_sigma": 5, - }, - { - "type": "gaussian_noise", - "proba": p, - "std": 0.5, - }, - { - "type": "sharpen", - "proba": p, - "min_alpha": 0, - "max_alpha": 1, - "min_strength": 0, - "max_strength": 1, - }, + case Preprocessing.FixedWidthResize: + transforms.append(FixedWidthResize(width=preprocessing["fixed_width"])) + return Compose(transforms) + + +def get_augmentation_transforms() -> A.Compose: + """ + Returns a list of transformation to be applied to the image. 
+ """ + return A.SomeOf( + [ + DPIAdjusting(min_factor=0.75, max_factor=1), + Perspective(scale=(0.05, 0.09), fit_output=True), + GaussianBlur(sigma_limit=2.5), + GaussNoise(var_limit=50**2), + ColorJitter(contrast=0.2, brightness=0.2, saturation=0.2, hue=0.2), + A.OneOf( + [ + ElasticTransform( + alpha=20.0, sigma=5.0, alpha_affine=1.0, border_mode=0 + ), + PiecewiseAffine(scale=(0.01, 0.04), nb_rows=1, nb_cols=4), + ] + ), + Sharpen(alpha=(0.0, 1.0)), + ErosionDilation(min_kernel=1, max_kernel=4, iterations=1), + Affine(shear={"x": (-20, 20), "y": (0, 0)}), + CoarseDropout(), + Downscale(scale_min=0.5, scale_max=0.9), + ToGray(), ], - } + n=2, + p=0.9, + ) diff --git a/docs/assets/augmentations/document_color_jitter.png b/docs/assets/augmentations/document_color_jitter.png new file mode 100644 index 0000000000000000000000000000000000000000..107aa41ec8dc57bcd2865d428cd7a192c9641df0 Binary files /dev/null and b/docs/assets/augmentations/document_color_jitter.png differ diff --git a/docs/assets/augmentations/document_downscale.png b/docs/assets/augmentations/document_downscale.png new file mode 100644 index 0000000000000000000000000000000000000000..4ac1affc9ae12460d4e763c4e3bbe5b6ff410f71 Binary files /dev/null and b/docs/assets/augmentations/document_downscale.png differ diff --git a/docs/assets/augmentations/document_dropout.png b/docs/assets/augmentations/document_dropout.png new file mode 100644 index 0000000000000000000000000000000000000000..16b39846306b804a622217e2a4951ddd200b915d Binary files /dev/null and b/docs/assets/augmentations/document_dropout.png differ diff --git a/docs/assets/augmentations/document_elastic.png b/docs/assets/augmentations/document_elastic.png new file mode 100644 index 0000000000000000000000000000000000000000..0cdf1962868f28d727adc73248279c962e76575d Binary files /dev/null and b/docs/assets/augmentations/document_elastic.png differ diff --git a/docs/assets/augmentations/document_erosion_dilation.png b/docs/assets/augmentations/document_erosion_dilation.png new file mode 100644 index 0000000000000000000000000000000000000000..997c677476c6c759a70044f154bb6395b64e16ad Binary files /dev/null and b/docs/assets/augmentations/document_erosion_dilation.png differ diff --git a/docs/assets/augmentations/document_full_pipeline.png b/docs/assets/augmentations/document_full_pipeline.png new file mode 100644 index 0000000000000000000000000000000000000000..e93588f8a35dc2278b22f30f8b908c8efff40e82 Binary files /dev/null and b/docs/assets/augmentations/document_full_pipeline.png differ diff --git a/docs/assets/augmentations/document_full_pipeline_2.png b/docs/assets/augmentations/document_full_pipeline_2.png new file mode 100644 index 0000000000000000000000000000000000000000..1fb373961f79e69acc48307962c80fdf6497407b Binary files /dev/null and b/docs/assets/augmentations/document_full_pipeline_2.png differ diff --git a/docs/assets/augmentations/document_gaussian_blur.png b/docs/assets/augmentations/document_gaussian_blur.png new file mode 100644 index 0000000000000000000000000000000000000000..c58b35313c11ab71c5ba74fc4ea8583955f58a92 Binary files /dev/null and b/docs/assets/augmentations/document_gaussian_blur.png differ diff --git a/docs/assets/augmentations/document_gaussian_noise.png b/docs/assets/augmentations/document_gaussian_noise.png new file mode 100644 index 0000000000000000000000000000000000000000..e85f4342cee9affc582eb2ec7173508e30068c31 Binary files /dev/null and b/docs/assets/augmentations/document_gaussian_noise.png differ diff --git 
a/docs/assets/augmentations/document_grayscale.png b/docs/assets/augmentations/document_grayscale.png new file mode 100644 index 0000000000000000000000000000000000000000..b367a174f30591bc9bfb3d3722f13cd66a15f983 Binary files /dev/null and b/docs/assets/augmentations/document_grayscale.png differ diff --git a/docs/assets/augmentations/document_perspective.png b/docs/assets/augmentations/document_perspective.png new file mode 100644 index 0000000000000000000000000000000000000000..49120af69db46c1ab4f434a2ce0d16005610b420 Binary files /dev/null and b/docs/assets/augmentations/document_perspective.png differ diff --git a/docs/assets/augmentations/document_piecewise.png b/docs/assets/augmentations/document_piecewise.png new file mode 100644 index 0000000000000000000000000000000000000000..0dce066028518009409690760e23c432402a8fe5 Binary files /dev/null and b/docs/assets/augmentations/document_piecewise.png differ diff --git a/docs/assets/augmentations/document_sharpen.png b/docs/assets/augmentations/document_sharpen.png new file mode 100644 index 0000000000000000000000000000000000000000..721b1d38d09867f1f2e7bc400ae3a2b9609daf7f Binary files /dev/null and b/docs/assets/augmentations/document_sharpen.png differ diff --git a/docs/assets/augmentations/document_shearx.png b/docs/assets/augmentations/document_shearx.png new file mode 100644 index 0000000000000000000000000000000000000000..1fa6727c78a32f057f20eb5d4bb8f0617c732f12 Binary files /dev/null and b/docs/assets/augmentations/document_shearx.png differ diff --git a/docs/assets/augmentations/line_color_jitter.png b/docs/assets/augmentations/line_color_jitter.png new file mode 100644 index 0000000000000000000000000000000000000000..55bc3a2f05eae9f5d976aeb0b9d0d22427d79e40 Binary files /dev/null and b/docs/assets/augmentations/line_color_jitter.png differ diff --git a/docs/assets/augmentations/line_downscale.png b/docs/assets/augmentations/line_downscale.png new file mode 100644 index 0000000000000000000000000000000000000000..1eca4194bfe4a1bc2e17c493cf7f3cb89b87f2e9 Binary files /dev/null and b/docs/assets/augmentations/line_downscale.png differ diff --git a/docs/assets/augmentations/line_dropout.png b/docs/assets/augmentations/line_dropout.png new file mode 100644 index 0000000000000000000000000000000000000000..3a835456d75e1f917339146052b9124fd63c29a0 Binary files /dev/null and b/docs/assets/augmentations/line_dropout.png differ diff --git a/docs/assets/augmentations/line_elastic.png b/docs/assets/augmentations/line_elastic.png new file mode 100644 index 0000000000000000000000000000000000000000..0c3e873092f09c2a4c03737ba87e1cb4194f0903 Binary files /dev/null and b/docs/assets/augmentations/line_elastic.png differ diff --git a/docs/assets/augmentations/line_erosion_dilation.png b/docs/assets/augmentations/line_erosion_dilation.png new file mode 100644 index 0000000000000000000000000000000000000000..ac920e9dc2eae3673aa03b3199b0efc949bdc0d7 Binary files /dev/null and b/docs/assets/augmentations/line_erosion_dilation.png differ diff --git a/docs/assets/augmentations/line_full_pipeline.png b/docs/assets/augmentations/line_full_pipeline.png new file mode 100644 index 0000000000000000000000000000000000000000..0db8d89e6f1121dfed7e965be323d623b6c3a178 Binary files /dev/null and b/docs/assets/augmentations/line_full_pipeline.png differ diff --git a/docs/assets/augmentations/line_gaussian_blur.png b/docs/assets/augmentations/line_gaussian_blur.png new file mode 100644 index 0000000000000000000000000000000000000000..364ecb0a3700164b2bb64fb3b96296f6a1e41ec9 
Binary files /dev/null and b/docs/assets/augmentations/line_gaussian_blur.png differ diff --git a/docs/assets/augmentations/line_gaussian_noise.png b/docs/assets/augmentations/line_gaussian_noise.png new file mode 100644 index 0000000000000000000000000000000000000000..2667b5ecdf599239705292ec947c11594d735ef7 Binary files /dev/null and b/docs/assets/augmentations/line_gaussian_noise.png differ diff --git a/docs/assets/augmentations/line_grayscale.png b/docs/assets/augmentations/line_grayscale.png new file mode 100644 index 0000000000000000000000000000000000000000..92cec3ded0f887bb23772423a77d56b1b035d56e Binary files /dev/null and b/docs/assets/augmentations/line_grayscale.png differ diff --git a/docs/assets/augmentations/line_perspective.png b/docs/assets/augmentations/line_perspective.png new file mode 100644 index 0000000000000000000000000000000000000000..8c0f596b49634b1a20e43cae3f8dfcbbb1764860 Binary files /dev/null and b/docs/assets/augmentations/line_perspective.png differ diff --git a/docs/assets/augmentations/line_piecewise.png b/docs/assets/augmentations/line_piecewise.png new file mode 100644 index 0000000000000000000000000000000000000000..1b7181d3278fdc7b784fad2e2a00ff00427465fd Binary files /dev/null and b/docs/assets/augmentations/line_piecewise.png differ diff --git a/docs/assets/augmentations/line_sharpen.png b/docs/assets/augmentations/line_sharpen.png new file mode 100644 index 0000000000000000000000000000000000000000..b50da170860b0361c8bbb2839bad0728e53b30f3 Binary files /dev/null and b/docs/assets/augmentations/line_sharpen.png differ diff --git a/docs/assets/augmentations/line_shearx.png b/docs/assets/augmentations/line_shearx.png new file mode 100644 index 0000000000000000000000000000000000000000..0aea6afbdc2748f563f9067f09b1854db6d68eb4 Binary files /dev/null and b/docs/assets/augmentations/line_shearx.png differ diff --git a/docs/usage/train/augmentation.md b/docs/usage/train/augmentation.md new file mode 100644 index 0000000000000000000000000000000000000000..425e18d72c9ad28cc05b6df232da7c2aea269873 --- /dev/null +++ b/docs/usage/train/augmentation.md @@ -0,0 +1,137 @@ +# Data augmentation transforms + +This page lists data augmentation transforms used in DAN. + +## Individual augmentation transforms + +### Elastic transform + +| | Elastic Transforms | +| ---------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| Description | This transformation applies local distortions that rotate characters locally | +| Comments | The impact of this transformation is mostly visible on documents, not so much on lines. Results are comparable to the original DAN implementation. | +| Documentation | See the [`albumentations` documentation](https://albumentations.ai/docs/api_reference/augmentations/geometric/transforms/#albumentations.augmentations.geometric.transforms.ElasticTransform) | +| Examples |   | +| CPU time (seconds/10 images) | 0.44 (3013x128 pixels) / 0.86 (1116x581 pixels) | + +### PieceWise Affine + +| | PieceWiseAffine | +| ---------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| Description | This transformation also applies local distortions but with a larger grid than EslasticTransform. | +| Comments | This transformation is very slow. 
It is a new transform that was not in the original DAN implementation. |
+| Documentation                | See the [`albumentations` documentation](https://albumentations.ai/docs/api_reference/augmentations/geometric/transforms/#albumentations.augmentations.geometric.transforms.PiecewiseAffine) |
+| Examples                     |  |
+| CPU time (seconds/10 images) | 2.92 (3013x128 pixels) / 3.76 (1116x581 pixels) |
+
+### Dilation Erosion
+
+|                              | Dilation & erosion |
+| ---------------------------- | ------------------ |
+| Description                  | This transformation makes the pen stroke thicker or thinner. |
+| Comments                     | The `ErosionDilation` class randomly selects a kernel size and applies a dilation or an erosion to the image. It relies on OpenCV and is similar to the original DAN implementation. |
+| Documentation                | See the [`opencv` documentation](https://docs.opencv.org/3.4/db/df6/tutorial_erosion_dilatation.html) |
+| Examples                     |  |
+| CPU time (seconds/10 images) | 0.02 (3013x128 pixels) / 0.03 (1116x581 pixels) |
+
+### Sharpen
+
+|                              | Sharpen |
+| ---------------------------- | ------- |
+| Description                  | This transformation makes the image sharper. |
+| Comments                     | Similar to the original DAN implementation. |
+| Documentation                | See the [`albumentations` documentation](https://albumentations.ai/docs/api_reference/augmentations/transforms/#albumentations.augmentations.transforms.Sharpen) |
+| Examples                     |  |
+| CPU time (seconds/10 images) | 0.02 (3013x128 pixels) / 0.04 (1116x581 pixels) |
+
+### Color Jittering
+
+|                              | Color jittering |
+| ---------------------------- | --------------- |
+| Description                  | This transformation alters the colors in the image. |
+| Comments                     | Similar to the original DAN implementation. |
+| Documentation                | See the [`albumentations` documentation](https://albumentations.ai/docs/api_reference/augmentations/transforms/#albumentations.augmentations.transforms.ColorJitter) |
+| Examples                     |  |
+| CPU time (seconds/10 images) | 0.03 (3013x128 pixels) / 0.04 (1116x581 pixels) |
+
+### Gaussian noise
+
+|                              | Gaussian noise |
+| ---------------------------- | -------------- |
+| Description                  | This transformation adds Gaussian noise to the image. |
+| Comments                     | The noise from the original DAN implementation is more uniform. |
+| Documentation                | See the [`albumentations` documentation](https://albumentations.ai/docs/api_reference/augmentations/transforms/#albumentations.augmentations.transforms.GaussianNoise) |
+| Examples                     |  |
+| CPU time (seconds/10 images) | 0.29 (3013x128 pixels) / 0.53 (1116x581 pixels) |
+
+### Gaussian blur
+
+|                              | Gaussian blur |
+| ---------------------------- | ------------- |
+| Description                  | This transformation blurs the image. |
+| Comments                     | Similar to the original DAN implementation. 
| +| Documentation | See the [`albumentations` documentation](https://albumentations.ai/docs/api_reference/augmentations/transforms/#albumentations.augmentations.transforms.GaussianBlur) | +| Examples |   | +| CPU time (seconds/10 images) | 0.01 (3013x128 pixels) / 0.02 (1116x581 pixels) | + +### Random perspective + +| | Random perspective | +| ---------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| Description | This transformation changes the perspective from which the photo is taken. | +| Comments | Similar to the original DAN implementation. | +| Documentation | See the [`albumentations` documentation](https://albumentations.ai/docs/api_reference/augmentations/transforms/#albumentations.augmentations.transforms.Perspective) | +| Examples |   | +| CPU time (seconds/10 images) | 0.05 (3013x128 pixels) / 0.05 (1116x581 pixels) | + +### Shearing (x-axis) + +| | Shearing (x-axis) | +| ---------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| Description | This transformation changes the slant of the text on the image. | +| Comments | New transform that was not in the original DAN implementation. | +| Documentation | See the [`albumentations` documentation](https://albumentations.ai/docs/api_reference/augmentations/geometric/transforms/#albumentations.augmentations.geometric.transforms.Affine) | +| Examples |   | +| CPU time (seconds/10 images) | 0.05 (3013x128 pixels) / 0.04 (1116x581 pixels) | + +### Coarse Dropout + +| | Coarse Dropout | +| ---------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | +| Description | This transformation adds dropout on the image, turning small patches into black pixels. | +| Comments | It is a new transform that was not in the original DAN implementation. | +| Documentation | See the [`albumentations` documentation](https://albumentations.ai/docs/api_reference/augmentations/dropout/coarse_dropout/#coarsedropout-augmentation-augmentationsdropoutcoarse_dropout) | +| Examples |   | +| CPU time (seconds/10 images) | 0.02 (3013x128 pixels) / 0.02 (1116x581 pixels) | + +### Downscale + +| | Downscale | +| ---------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------ | +| Description | This transformation downscales the image from a random factor. | +| Comments | It is a new transform that was not in the original DAN implementation. | +| Documentation | See the [`albumentations` documentation](https://albumentations.ai/docs/api_reference/augmentations/transforms/#albumentations.augmentations.transforms.Downscale) | +| Examples |   | +| CPU time (seconds/10 images) | 0.03 (3013x128 pixels) / 0.03 (1116x581 pixels) | + +### Grayscale + +| | Grayscale | +| ---------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| Description | This transformation transforms an RGB image into grayscale. 
| +| Comments | It is a new transform that was not in the original DAN implementation. | +| Documentation | See the [`albumentations` documentation](https://albumentations.ai/docs/api_reference/augmentations/transforms/#albumentations.augmentations.transforms.ToGray) | +| Examples |   | +| CPU time (seconds/10 images) | 0.02 (3013x128 pixels) / 0.02 (1116x581 pixels) | + +## Full augmentation pipeline + +* Data augmentation is applied with a probability of 0.9. +* In this case, two transformation are randomly selected to be applied. +* `ElasticTransform` and `PieceWiseAffine` cannot be applied on the same image. +* Reproducibility is possible by setting `random.seed` and `np.random.seed` (already done in `dan/ocr/document/train.py`) +* Examples with new pipeline: + + + + diff --git a/docs/usage/train/index.md b/docs/usage/train/index.md index 62a5fe0ee9326a45dc26e389262779d5087b37c5..2627b06902f5f353e7a05eb497550767edb5925d 100644 --- a/docs/usage/train/index.md +++ b/docs/usage/train/index.md @@ -21,3 +21,4 @@ To train DAN on lines, run `teklia-dan train document` with a line dataset. ## Additional page * [Jean Zay tutorial](jeanzay.md) +* [Data augmentation](augmentation.md) diff --git a/docs/usage/train/parameters.md b/docs/usage/train/parameters.md index e729c82afd66a5a04bce25a0a7c0b0186ffd6634..3a92f667b9de60a480df7c6a79a0f4b6fa9b732c 100644 --- a/docs/usage/train/parameters.md +++ b/docs/usage/train/parameters.md @@ -3,65 +3,74 @@ All hyperparameters are specified and editable in the training scripts (meaning ## Dataset parameters -| Parameter | Description | Type | Default | -| --------------------------------------- | -------------------------------------------------------------------------------------- | ------------ | ---------------------------------------------- | -| `dataset_name` | Name of the dataset. | `str` | | -| `dataset_level` | Level of the dataset. Should be named after the element type. | `str` | | -| `dataset_variant` | Variant of the dataset. Usually empty for HTR datasets, `"_sem"` for HTR+NER datasets. | `str` | | -| `dataset_path` | Path to the dataset. | `str` | | -| `dataset_params.config.dataset_manager` | Dataset manager class. | custom class | `OCRDatasetManager` | -| `dataset_params.config.dataset_class` | Dataset class. | custom class | `OCRDataset` | -| `dataset_params.config.datasets` | Dataset dictionary with the dataset name as key and dataset path as value. | `dict` | | -| `dataset_params.config.load_in_memory` | Load all images in CPU memory. | `str` | `True` | -| `dataset_params.config.worker_per_gpu` | Number of parallel processes per gpu for data loading. | `int` | `4` | -| `dataset_params.config.preprocessings` | List of pre-processing functions to apply to input images. | `list` | (see [dedicated section](#data-preprocessing)) | -| `dataset_params.config.augmentation` | Configuration for data augmentation. | `dict` | (see [dedicated section](#data-augmentation)) | +| Parameter | Description | Type | Default | +| --------------------------------------- | -------------------------------------------------------------------------------------- | ------------ | ---------------------------------------------------- | +| `dataset_name` | Name of the dataset. | `str` | | +| `dataset_level` | Level of the dataset. Should be named after the element type. | `str` | | +| `dataset_variant` | Variant of the dataset. Usually empty for HTR datasets, `"_sem"` for HTR+NER datasets. | `str` | | +| `dataset_path` | Path to the dataset. 
| `str` | |
+| `dataset_params.config.dataset_manager` | Dataset manager class. | custom class | `OCRDatasetManager` |
+| `dataset_params.config.dataset_class` | Dataset class. | custom class | `OCRDataset` |
+| `dataset_params.config.datasets` | Dataset dictionary with the dataset name as key and dataset path as value. | `dict` | |
+| `dataset_params.config.load_in_memory` | Load all images in CPU memory. | `str` | `True` |
+| `dataset_params.config.worker_per_gpu` | Number of parallel processes per gpu for data loading. | `int` | `4` |
+| `dataset_params.config.preprocessings` | List of pre-processing functions to apply to input images. | `list` | (see [dedicated section](#data-preprocessing)) |
+| `dataset_params.config.augmentation` | Whether to use data augmentation on the training set. | `bool` | `True` (see [dedicated section](#data-augmentation)) |
 
 ### Data preprocessing
 
-Preprocessing is applied before training the network (see `dan/manager/dataset.py`).
-The following transformations are implemented:
+Preprocessing is applied before training the network (see `dan/manager/dataset.py`). The list of accepted transforms is defined in `dan/transforms.py`:
 
-* Convert to grayscale
 ```py
-{
-    "type": "to_grayscaled"
-}
+class Preprocessing(Enum):
+    # If the image is bigger than the given size, resize it while keeping the original ratio
+    MaxResize = "max_resize"
+    # Resize the height to a fixed value while keeping the original ratio
+    FixedHeightResize = "fixed_height_resize"
+    # Resize the width to a fixed value while keeping the original ratio
+    FixedWidthResize = "fixed_width_resize"
 ```
-* Convert to RGB
+
+Usage:
+* Resize to a fixed height
 ```py
-{
-    "type": "to_RGB"
-}
+[
+    {
+        "type": Preprocessing.FixedHeightResize,
+        "fixed_height": 1500,
+    }
+]
 ```
-* Resize to a fixed height
+* Resize to a fixed width
 ```py
-{
-    "type": "fixed_height",
-    "fixed_height": 1000,
-}
+[
+    {
+        "type": Preprocessing.FixedWidthResize,
+        "fixed_width": 1500,
+    }
+]
 ```
-* Resize to a maximum size
+* Resize to a maximum size (only if the image is bigger than the given size)
 ```py
-{
-    "type": "resize",
-    "keep_ratio": True,
-    "max_height": 1000,
-    "max_width": None,
-}
+[
+    {
+        "type": Preprocessing.MaxResize,
+        "max_height": 2000,
+        "max_width": 2000,
+    }
+]
 ```
-
-Multiple transformations can be combined. For example, to resize an image to a fixed height of 1000 pixels and convert images to RGB, use the following configuration in `dataset_params.config.preprocessings`:
-
+* Combine these pre-processings
 ```py
 [
     {
-        "type": "fixed_height",
-        "fixed_height": 1000
+        "type": Preprocessing.FixedHeightResize,
+        "fixed_height": 2000,
     },
     {
-        "type": "to_RGB"
+        "type": Preprocessing.FixedWidthResize,
+        "fixed_width": 2000,
     }
 ]
 ```
@@ -70,91 +79,44 @@ Multiple transformations can be combined. For example, to resize an image to a f
 ### Data augmentation
 
 Augmentation transformations are applied on-the-fly during training to artificially increase data variability.
 
-The following transformations are implemented in `dan/transforms.py`:
-* Color inversion
-* DPI adjusting
-* Dilation and erosion
-* Elastic distortion
-* Reducing interline spacing
-* Gaussian blur
-* Gaussian noise
-
-DAN also takes advantage of [transforms from torchvision](https://pytorch.org/vision/stable/transforms.html):
-* ColorJitter
-* GaussianBlur
-* RandomCrop
-* RandomPerspective
-
-The following configuration is used by default when using the `teklia-dan train document` command. 
Data augmentation is applied with a probability of 0.9, and each transformation has a 0.1 probability to be used. +DAN takes advantage of transforms from [albumentations](https://albumentations.ai/). +The following configuration is used by default when using the `teklia-dan train document` command. Data augmentation is applied with a probability of 0.9. In this case, two transformations are randomly selected to be applied. ```py -{ - "order": "random", - "proba": 0.9, - "augmentations": [ - { - "type": "dpi", - "proba": 0.1, - "min_factor": 0.75, - "max_factor": 1, - "preserve_ratio": True, - }, - { - "type": "perspective", - "proba": 0.1, - "min_factor": 0, - "max_factor": 0.4, - }, - { - "type": "elastic_distortion", - "proba": 0.1, - "min_alpha": 0.5, - "max_alpha": 1, - "min_sigma": 1, - "max_sigma": 10, - "min_kernel_size": 3, - "max_kernel_size": 9, - }, - { - "type": "dilation_erosion", - "proba": 0.1, - "min_kernel": 1, - "max_kernel": 3, - "iterations": 1, - }, - { - "type": "color_jittering", - "proba": 0.1, - "factor_hue": 0.2, - "factor_brightness": 0.4, - "factor_contrast": 0.4, - "factor_saturation": 0.4, - }, - { - "type": "gaussian_blur", - "proba": 0.1, - "min_kernel": 3, - "max_kernel": 5, - "min_sigma": 3, - "max_sigma": 5, - }, - { - "type": "gaussian_noise", - "proba": 0.1, - "std": 0.5, - }, - { - "type": "sharpen", - "proba": 0.1, - "min_alpha": 0, - "max_alpha": 1, - "min_strength": 0, - "max_strength": 1, - }, - ], - } +transforms = A.Compose( + A.SomeOf( + [ + DPIAdjusting(min_factor=0.75, max_factor=1), + Perspective(scale=(0.05, 0.09), fit_output=True), + GaussianBlur(sigma_limit=2.5), + GaussNoise(var_limit=50**2), + ColorJitter(contrast=0.2, brightness=0.2, saturation=0.2, hue=0.2), + A.OneOf( + [ + ElasticTransform( + alpha=20.0, + sigma=5.0, + alpha_affine=1.0, + border_mode=0, + ), + PiecewiseAffine(scale=(0.01, 0.04), nb_rows=1, nb_cols=4), + ] + ), + Sharpen(alpha=(0.0, 1.0)), + ErosionDilation(min_kernel=1, max_kernel=4, iterations=1), + Affine(shear={"x": (-20, 20), "y": (0, 0)}), + CoarseDropout(), + Downscale(scale_min=0.5, scale_max=0.9), + ToGray(), + ], + n=2, + ), + p=0.9, + ) ``` +For a detailed description of all augmentation transforms, see the [dedicated page](augmentation.md). + ## Model parameters | Name | Description | Type | Default | @@ -188,32 +150,32 @@ The following configuration is used by default when using the `teklia-dan train ## Training parameters -| Name | Description | Type | Default | -| ----------------------------------------------------------- | --------------------------------------------------------------------------- | ------------ | ------------------------------------------- | -| `training_params.output_folder` | Directory for checkpoint and results. | `str` | | -| `training_params.max_nb_epochs` | Maximum number of epochs before stopping training. | `int` | `800` | -| `training_params.max_training_time` | Maximum time (in seconds) before stopping training. | `int` | `350000` | -| `training_params.load_epoch` | Model to load. Should be either `"best"` (evaluation) or `last` (training). | `str` | `"last"` | -| `training_params.interval_save_weights` | Step to save weights. Set to `None` to keep only best and last epochs. | `int` | `None` | -| `training_params.batch_size` | Mini-batch size for the training loop. | `int` | `2` | -| `training_params.valid_batch_size` | Mini-batch size for the valdiation loop. | `int` | `4` | -| `training_params.use_ddp` | Whether to use DistributedDataParallel. 
| `bool` | `False` |
-| `training_params.ddp_port` | DDP port. | `int` | `20027` |
-| `training_params.use_amp` | Whether to enable automatic mix-precision. | `int` | `torch.cuda.device_count()` |
-| `training_params.nb_gpu` | Number of GPUs to train DAN. | `str` | |
-| `training_params.optimizers.all.class` | Optimizer class. | custom class | `Adam` |
-| `training_params.optimizers.all.args.lr` | Learning rate for the optimizer. | `float` | `0.0001` |
-| `training_params.optimizers.all.args.amsgrad` | Whether to use AMSGrad optimization. | custom class | `False` |
-| `training_params.lr_schedulers` | Learning rate schedulers. | custom class | `None` |
-| `training_params.eval_on_valid` | Whether to evaluate and log metrics on the validation set during training. | `bool` | `True` |
-| `training_params.eval_on_valid_interval` | Interval (in epochs) to evaluate during training. | `int` | `5` |
-| `training_params.focus_metric` | Metrics to focus on to determine best epoch. | `str` | `cer` |
-| `training_params.expected_metric_value` | Best value for the focus metric. Should be either `"high"` or `"low"`. | `low` | `cer` |
-| `training_params.set_name_focus_metric` | Dataset to focus on to select best weights. | `str` | |
-| `training_params.train_metrics` | List of metrics to compute during training. | `list` | `["loss_ce", "cer", "wer", "wer_no_punct"]` |
-| `training_params.train_metrics` | List of metrics to compute during validation. | `list` | `["cer", "wer", "wer_no_punct"]` |
-| `training_params.force_cpu` | Whether to train on CPU (for debugging). | `bool` | `False` |
-| `training_params.max_char_prediction` | Maximum number of characters to predict. | `int` | `1000` |
+| Name | Description | Type | Default |
+| ------------------------------------------------------- | --------------------------------------------------------------------------- | ------------ | ------------------------------------------- |
+| `training_params.output_folder` | Directory for checkpoint and results. | `str` | |
+| `training_params.max_nb_epochs` | Maximum number of epochs before stopping training. | `int` | `800` |
+| `training_params.max_training_time` | Maximum time (in seconds) before stopping training. | `int` | `350000` |
+| `training_params.load_epoch` | Model to load. Should be either `"best"` (evaluation) or `last` (training). | `str` | `"last"` |
+| `training_params.interval_save_weights` | Step to save weights. Set to `None` to keep only best and last epochs. | `int` | `None` |
+| `training_params.batch_size` | Mini-batch size for the training loop. | `int` | `2` |
+| `training_params.valid_batch_size` | Mini-batch size for the validation loop. | `int` | `4` |
+| `training_params.use_ddp` | Whether to use DistributedDataParallel. | `bool` | `False` |
+| `training_params.ddp_port` | DDP port. | `int` | `20027` |
+| `training_params.use_amp` | Whether to enable automatic mixed precision. | `int` | `torch.cuda.device_count()` |
+| `training_params.nb_gpu` | Number of GPUs to train DAN. | `str` | |
+| `training_params.optimizers.all.class` | Optimizer class. | custom class | `Adam` |
+| `training_params.optimizers.all.args.lr` | Learning rate for the optimizer. | `float` | `0.0001` |
+| `training_params.optimizers.all.args.amsgrad` | Whether to use AMSGrad optimization. | custom class | `False` |
+| `training_params.lr_schedulers` | Learning rate schedulers. | custom class | `None` |
+| `training_params.eval_on_valid` | Whether to evaluate and log metrics on the validation set during training. 
| `bool` | `True` | +| `training_params.eval_on_valid_interval` | Interval (in epochs) to evaluate during training. | `int` | `5` | +| `training_params.focus_metric` | Metrics to focus on to determine best epoch. | `str` | `cer` | +| `training_params.expected_metric_value` | Best value for the focus metric. Should be either `"high"` or `"low"`. | `low` | `cer` | +| `training_params.set_name_focus_metric` | Dataset to focus on to select best weights. | `str` | | +| `training_params.train_metrics` | List of metrics to compute during training. | `list` | `["loss_ce", "cer", "wer", "wer_no_punct"]` | +| `training_params.train_metrics` | List of metrics to compute during validation. | `list` | `["cer", "wer", "wer_no_punct"]` | +| `training_params.force_cpu` | Whether to train on CPU (for debugging). | `bool` | `False` | +| `training_params.max_char_prediction` | Maximum number of characters to predict. | `int` | `1000` | | `training_params.label_noise_scheduler.min_error_rate` | Minimum ratio of teacher forcing. | `float` | `0.2` | | `training_params.label_noise_scheduler.max_error_rate` | Maximum ratio of teacher forcing. | `float` | `0.2` | | `training_params.label_noise_scheduler.total_num_steps` | Number of steps before stopping teacher forcing. | `float` | `5e4` | diff --git a/mkdocs.yml b/mkdocs.yml index 8aa69a2efaf9e80ca4a4e470ccc4798a2b56dcf3..3c7787d4cb9616024eb7203881731e841bf4ad30 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -62,6 +62,7 @@ nav: - Training: - usage/train/index.md - Parameters: usage/train/parameters.md + - Data augmentation: usage/train/augmentation.md - Jean Zay tutorial: usage/train/jeanzay.md - Predict: usage/predict.md - Documentation development: dev/build_docs.md diff --git a/requirements.txt b/requirements.txt index 39d00baf289bc73fa171cf313269e4a9822ed31a..0a38d4faabc41fbc9f0b7fd46e01865aa2f6cd08 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ +albumentations==1.3.1 arkindex-export==0.1.3 boto3==1.26.124 editdistance==0.6.2 diff --git a/tests/conftest.py b/tests/conftest.py index 3752dbf985c1a286a31ea01905a7e062595bcd3f..a38d6a01fa4282106e67d9193822a0e131704147 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,7 +9,7 @@ from torch.optim import Adam from dan.decoder import GlobalHTADecoder from dan.encoder import FCN_Encoder from dan.schedulers import exponential_dropout_scheduler -from dan.transforms import aug_config +from dan.transforms import Preprocessing FIXTURES = Path(__file__).resolve().parent / "data" @@ -70,11 +70,12 @@ def training_config(): "load_in_memory": True, # Load all images in CPU memory "preprocessings": [ { - "type": "to_RGB", - # if grayscaled image, produce RGB one (3 channels with same value) otherwise do nothing + "type": Preprocessing.MaxResize, + "max_width": 2000, + "max_height": 2000, }, ], - "augmentation": aug_config(0.9, 0.1), + "augmentation": True, }, }, "model_params": { diff --git a/tests/data/training/models/best_0.pt b/tests/data/training/models/best_0.pt index 1bf2702bbdea7aa93d14d4cb14027a4fed451d97..9190c418d96dac5aafbaa8d466314ef305ed46e7 100644 --- a/tests/data/training/models/best_0.pt +++ b/tests/data/training/models/best_0.pt @@ -1,3 +1,7 @@ version https://git-lfs.github.com/spec/v1 +<<<<<<< HEAD oid sha256:3199e188056836ee2b907319c72c24abfb3b83d850dad8951033a63effa89e72 +======= +oid sha256:81d9481289aa52a6c0b9b20a487881b9e5efee2e3559b6fd2873d0c4e15ae9ba +>>>>>>> 78e6f60 (Use torchvision functions / transforms for data augmentation) size 84773087 diff --git 
a/tests/data/training/models/last_3.pt b/tests/data/training/models/last_3.pt index 27e32d82a4f4e5533844c8208cce7b482e922e6c..5c82fb729e19c5e55686b0e675a1a3beec4c7071 100644 --- a/tests/data/training/models/last_3.pt +++ b/tests/data/training/models/last_3.pt @@ -1,3 +1,8 @@ version https://git-lfs.github.com/spec/v1 +<<<<<<< HEAD oid sha256:c62f5090b1ae30e55a758a4d3ce9814c754f336b241438e106a98c630e0c31e6 size 84773279 +======= +oid sha256:e2917e210f1c78b5abbd32b5d14c1c8157dc28fc6b20369bbfda21e9dc50cfd6 +size 84773087 +>>>>>>> 78e6f60 (Use torchvision functions / transforms for data augmentation) diff --git a/tests/test_training.py b/tests/test_training.py index b2c4ebbb33c702e0134f654f7dea1474fe88e813..630e9506330d82998e00d5b628c5a2281eabe242 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -25,7 +25,7 @@ from tests.conftest import FIXTURES }, { "nb_chars": 41, - "cer": 1.2683, + "cer": 1.2927, "nb_words": 9, "wer": 1.0, "nb_words_no_punct": 9, @@ -34,7 +34,7 @@ from tests.conftest import FIXTURES }, { "nb_chars": 49, - "cer": 1.1429, + "cer": 1.102, "nb_words": 9, "wer": 1.0, "nb_words_no_punct": 9,
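
For reference, the two entry points added in `dan/transforms.py` compose as shown below. This is a minimal sketch of how `OCRDataset` chains them (pre-processing on the PIL image, augmentation on the numpy array); only the input image path is hypothetical, everything else follows the API introduced in this diff.

```py
# Minimal usage sketch of the new transform entry points (illustration only).
import numpy as np
from PIL import Image

from dan.transforms import (
    Preprocessing,
    get_augmentation_transforms,
    get_preprocessing_transforms,
)

# Build the same pipelines the dataset classes build from the config.
preprocessing = get_preprocessing_transforms(
    [{"type": Preprocessing.MaxResize, "max_height": 2000, "max_width": 2000}]
)
augmentation = get_augmentation_transforms()

img = Image.open("page.jpg").convert("RGB")   # hypothetical input image
img = preprocessing(img)                      # torchvision Compose, applied to the PIL image
img = np.array(img)                           # albumentations transforms expect numpy arrays
augmented = augmentation(image=img)["image"]  # A.SomeOf picks n=2 transforms with p=0.9
```

Because the augmentation pipeline is an `A.SomeOf` block, at most two of the listed transforms are applied per call, and only with probability 0.9, mirroring the behaviour documented in `docs/usage/train/augmentation.md`.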