Skip to content
Snippets Groups Projects
Verified Commit 71d1406c authored by Mélodie Boillet's avatar Mélodie Boillet
Browse files

Simplify mean and std computation

parent 8a0ee85b
No related branches found
No related tags found
1 merge request!172Simplify mean and std computation
......@@ -268,7 +268,10 @@ class GenericDataset(Dataset):
def apply_preprocessing(self, preprocessings):
for i in range(len(self.samples)):
self.samples[i] = apply_preprocessing(self.samples[i], preprocessings)
(
self.samples[i]["img"],
self.samples[i]["resize_ratio"],
) = apply_preprocessing(self.samples[i]["img"], preprocessings)
def compute_std_mean(self):
"""
......@@ -276,46 +279,33 @@ class GenericDataset(Dataset):
"""
if self.mean is not None and self.std is not None:
return self.mean, self.std
if not self.load_in_memory:
sample = self.samples[0].copy()
sample["img"] = self.get_sample_img(0)
img = apply_preprocessing(sample, self.params["config"]["preprocessings"])[
"img"
]
else:
img = self.get_sample_img(0)
_, _, c = img.shape
sum = np.zeros((c,))
sum = np.zeros((3,))
diff = np.zeros((3,))
nb_pixels = 0
for metric in ["mean", "std"]:
for ind in range(len(self.samples)):
img = (
self.get_sample_img(ind)
if self.load_in_memory
else apply_preprocessing(
self.get_sample_img(ind),
self.params["config"]["preprocessings"],
)[0]
)
for i in range(len(self.samples)):
if not self.load_in_memory:
sample = self.samples[i].copy()
sample["img"] = self.get_sample_img(i)
img = apply_preprocessing(
sample, self.params["config"]["preprocessings"]
)["img"]
else:
img = self.get_sample_img(i)
sum += np.sum(img, axis=(0, 1))
nb_pixels += np.prod(img.shape[:2])
mean = sum / nb_pixels
diff = np.zeros((c,))
for i in range(len(self.samples)):
if not self.load_in_memory:
sample = self.samples[i].copy()
sample["img"] = self.get_sample_img(i)
img = apply_preprocessing(
sample, self.params["config"]["preprocessings"]
)["img"]
else:
img = self.get_sample_img(i)
diff += [np.sum((img[:, :, k] - mean[k]) ** 2) for k in range(c)]
std = np.sqrt(diff / nb_pixels)
self.mean = mean
self.std = std
return mean, std
if metric == "mean":
sum += np.sum(img, axis=(0, 1))
nb_pixels += np.prod(img.shape[:2])
elif metric == "std":
diff += [
np.sum((img[:, :, k] - self.mean[k]) ** 2) for k in range(3)
]
if metric == "mean":
self.mean = sum / nb_pixels
elif metric == "std":
self.std = np.sqrt(diff / nb_pixels)
return self.mean, self.std
def apply_data_augmentation(self, img):
"""
......@@ -340,12 +330,11 @@ class GenericDataset(Dataset):
return GenericDataset.load_image(self.samples[i]["path"])
def apply_preprocessing(sample, preprocessings):
def apply_preprocessing(img, preprocessings):
"""
Apply preprocessings on each sample
Apply preprocessings on an image
"""
resize_ratio = [1, 1]
img = sample["img"]
for preprocessing in preprocessings:
if preprocessing["type"] == "to_grayscaled":
temp_img = img
......@@ -394,6 +383,4 @@ def apply_preprocessing(sample, preprocessings):
img = temp_img
resize_ratio = [ratio, ratio]
sample["img"] = img
sample["resize_ratio"] = resize_ratio
return sample
return img, resize_ratio
......@@ -66,8 +66,8 @@ class OCRDataset(GenericDataset):
if not self.load_in_memory:
sample["img"] = self.get_sample_img(idx)
sample = apply_preprocessing(
sample, self.params["config"]["preprocessings"]
sample["img"], sample["resize_ratio"] = apply_preprocessing(
sample["img"], self.params["config"]["preprocessings"]
)
# Data augmentation
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment