Skip to content
Snippets Groups Projects
Commit 84bc0d4b authored by Mélodie Boillet's avatar Mélodie Boillet Committed by Marie Generali
Browse files

Simplify mean and std computation

parent 2066e615
No related branches found
No related tags found
No related merge requests found
...@@ -268,7 +268,10 @@ class GenericDataset(Dataset): ...@@ -268,7 +268,10 @@ class GenericDataset(Dataset):
def apply_preprocessing(self, preprocessings): def apply_preprocessing(self, preprocessings):
for i in range(len(self.samples)): for i in range(len(self.samples)):
self.samples[i] = apply_preprocessing(self.samples[i], preprocessings) (
self.samples[i]["img"],
self.samples[i]["resize_ratio"],
) = apply_preprocessing(self.samples[i]["img"], preprocessings)
def compute_std_mean(self): def compute_std_mean(self):
""" """
...@@ -276,46 +279,33 @@ class GenericDataset(Dataset): ...@@ -276,46 +279,33 @@ class GenericDataset(Dataset):
""" """
if self.mean is not None and self.std is not None: if self.mean is not None and self.std is not None:
return self.mean, self.std return self.mean, self.std
if not self.load_in_memory:
sample = self.samples[0].copy() sum = np.zeros((3,))
sample["img"] = self.get_sample_img(0) diff = np.zeros((3,))
img = apply_preprocessing(sample, self.params["config"]["preprocessings"])[
"img"
]
else:
img = self.get_sample_img(0)
_, _, c = img.shape
sum = np.zeros((c,))
nb_pixels = 0 nb_pixels = 0
for metric in ["mean", "std"]:
for ind in range(len(self.samples)):
img = (
self.get_sample_img(ind)
if self.load_in_memory
else apply_preprocessing(
self.get_sample_img(ind),
self.params["config"]["preprocessings"],
)[0]
)
for i in range(len(self.samples)): if metric == "mean":
if not self.load_in_memory: sum += np.sum(img, axis=(0, 1))
sample = self.samples[i].copy() nb_pixels += np.prod(img.shape[:2])
sample["img"] = self.get_sample_img(i) elif metric == "std":
img = apply_preprocessing( diff += [
sample, self.params["config"]["preprocessings"] np.sum((img[:, :, k] - self.mean[k]) ** 2) for k in range(3)
)["img"] ]
else: if metric == "mean":
img = self.get_sample_img(i) self.mean = sum / nb_pixels
sum += np.sum(img, axis=(0, 1)) elif metric == "std":
nb_pixels += np.prod(img.shape[:2]) self.std = np.sqrt(diff / nb_pixels)
mean = sum / nb_pixels return self.mean, self.std
diff = np.zeros((c,))
for i in range(len(self.samples)):
if not self.load_in_memory:
sample = self.samples[i].copy()
sample["img"] = self.get_sample_img(i)
img = apply_preprocessing(
sample, self.params["config"]["preprocessings"]
)["img"]
else:
img = self.get_sample_img(i)
diff += [np.sum((img[:, :, k] - mean[k]) ** 2) for k in range(c)]
std = np.sqrt(diff / nb_pixels)
self.mean = mean
self.std = std
return mean, std
def apply_data_augmentation(self, img): def apply_data_augmentation(self, img):
""" """
...@@ -340,12 +330,11 @@ class GenericDataset(Dataset): ...@@ -340,12 +330,11 @@ class GenericDataset(Dataset):
return GenericDataset.load_image(self.samples[i]["path"]) return GenericDataset.load_image(self.samples[i]["path"])
def apply_preprocessing(sample, preprocessings): def apply_preprocessing(img, preprocessings):
""" """
Apply preprocessings on each sample Apply preprocessings on an image
""" """
resize_ratio = [1, 1] resize_ratio = [1, 1]
img = sample["img"]
for preprocessing in preprocessings: for preprocessing in preprocessings:
if preprocessing["type"] == "to_grayscaled": if preprocessing["type"] == "to_grayscaled":
temp_img = img temp_img = img
...@@ -394,6 +383,4 @@ def apply_preprocessing(sample, preprocessings): ...@@ -394,6 +383,4 @@ def apply_preprocessing(sample, preprocessings):
img = temp_img img = temp_img
resize_ratio = [ratio, ratio] resize_ratio = [ratio, ratio]
sample["img"] = img return img, resize_ratio
sample["resize_ratio"] = resize_ratio
return sample
...@@ -66,8 +66,8 @@ class OCRDataset(GenericDataset): ...@@ -66,8 +66,8 @@ class OCRDataset(GenericDataset):
if not self.load_in_memory: if not self.load_in_memory:
sample["img"] = self.get_sample_img(idx) sample["img"] = self.get_sample_img(idx)
sample = apply_preprocessing( sample["img"], sample["resize_ratio"] = apply_preprocessing(
sample, self.params["config"]["preprocessings"] sample["img"], self.params["config"]["preprocessings"]
) )
# Data augmentation # Data augmentation
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment