diff --git a/dan/manager/dataset.py b/dan/manager/dataset.py
index 547ada96bb57de6027adddc2b53e4060c370446b..819e90664ddc39ef410489fdecdfbefa7f1ecbfd 100644
--- a/dan/manager/dataset.py
+++ b/dan/manager/dataset.py
@@ -268,7 +268,10 @@ class GenericDataset(Dataset):
 
     def apply_preprocessing(self, preprocessings):
         for i in range(len(self.samples)):
-            self.samples[i] = apply_preprocessing(self.samples[i], preprocessings)
+            (
+                self.samples[i]["img"],
+                self.samples[i]["resize_ratio"],
+            ) = apply_preprocessing(self.samples[i]["img"], preprocessings)
 
     def compute_std_mean(self):
         """
@@ -276,46 +279,33 @@ class GenericDataset(Dataset):
         """
         if self.mean is not None and self.std is not None:
             return self.mean, self.std
-        if not self.load_in_memory:
-            sample = self.samples[0].copy()
-            sample["img"] = self.get_sample_img(0)
-            img = apply_preprocessing(sample, self.params["config"]["preprocessings"])[
-                "img"
-            ]
-        else:
-            img = self.get_sample_img(0)
-        _, _, c = img.shape
-        sum = np.zeros((c,))
+
+        sum = np.zeros((3,))
+        diff = np.zeros((3,))
         nb_pixels = 0
+        for metric in ["mean", "std"]:
+            for ind in range(len(self.samples)):
+                img = (
+                    self.get_sample_img(ind)
+                    if self.load_in_memory
+                    else apply_preprocessing(
+                        self.get_sample_img(ind),
+                        self.params["config"]["preprocessings"],
+                    )[0]
+                )
 
-        for i in range(len(self.samples)):
-            if not self.load_in_memory:
-                sample = self.samples[i].copy()
-                sample["img"] = self.get_sample_img(i)
-                img = apply_preprocessing(
-                    sample, self.params["config"]["preprocessings"]
-                )["img"]
-            else:
-                img = self.get_sample_img(i)
-            sum += np.sum(img, axis=(0, 1))
-            nb_pixels += np.prod(img.shape[:2])
-        mean = sum / nb_pixels
-        diff = np.zeros((c,))
-        for i in range(len(self.samples)):
-            if not self.load_in_memory:
-                sample = self.samples[i].copy()
-                sample["img"] = self.get_sample_img(i)
-                img = apply_preprocessing(
-                    sample, self.params["config"]["preprocessings"]
-                )["img"]
-            else:
-                img = self.get_sample_img(i)
-            diff += [np.sum((img[:, :, k] - mean[k]) ** 2) for k in range(c)]
-        std = np.sqrt(diff / nb_pixels)
-
-        self.mean = mean
-        self.std = std
-        return mean, std
+                if metric == "mean":
+                    sum += np.sum(img, axis=(0, 1))
+                    nb_pixels += np.prod(img.shape[:2])
+                elif metric == "std":
+                    diff += [
+                        np.sum((img[:, :, k] - self.mean[k]) ** 2) for k in range(3)
+                    ]
+            if metric == "mean":
+                self.mean = sum / nb_pixels
+            elif metric == "std":
+                self.std = np.sqrt(diff / nb_pixels)
+        return self.mean, self.std
 
     def apply_data_augmentation(self, img):
         """
@@ -340,12 +330,11 @@ class GenericDataset(Dataset):
             return GenericDataset.load_image(self.samples[i]["path"])
 
 
-def apply_preprocessing(sample, preprocessings):
+def apply_preprocessing(img, preprocessings):
     """
-    Apply preprocessings on each sample
+    Apply preprocessings on an image
     """
     resize_ratio = [1, 1]
-    img = sample["img"]
     for preprocessing in preprocessings:
         if preprocessing["type"] == "to_grayscaled":
             temp_img = img
@@ -394,6 +383,4 @@ def apply_preprocessing(sample, preprocessings):
             img = temp_img
             resize_ratio = [ratio, ratio]
 
-    sample["img"] = img
-    sample["resize_ratio"] = resize_ratio
-    return sample
+    return img, resize_ratio
diff --git a/dan/manager/ocr.py b/dan/manager/ocr.py
index c60ce88cfb9f1503273bd81f32024b1e506ff359..da1db034b9e34a248159c3949b88e67a91788681 100644
--- a/dan/manager/ocr.py
+++ b/dan/manager/ocr.py
@@ -66,8 +66,8 @@ class OCRDataset(GenericDataset):
 
         if not self.load_in_memory:
             sample["img"] = self.get_sample_img(idx)
-            sample = apply_preprocessing(
-                sample, self.params["config"]["preprocessings"]
+            sample["img"], sample["resize_ratio"] = apply_preprocessing(
+                sample["img"], self.params["config"]["preprocessings"]
             )
 
         # Data augmentation