Skip to content
Snippets Groups Projects
Verified Commit 0ce6787a authored by Mélodie Boillet's avatar Mélodie Boillet
Browse files

Remove list of applied data augmentations

parent 0f97e244
No related branches found
No related tags found
No related merge requests found
...@@ -336,7 +336,7 @@ class GenericDataset(Dataset): ...@@ -336,7 +336,7 @@ class GenericDataset(Dataset):
for aug, set_name in zip(augs, ["train", "val", "test"]): for aug, set_name in zip(augs, ["train", "val", "test"]):
if aug and self.set_name == set_name: if aug and self.set_name == set_name:
return apply_data_augmentation(img, aug) return apply_data_augmentation(img, aug)
return img, list() return img
def get_sample_img(self, i): def get_sample_img(self, i):
""" """
......
...@@ -136,9 +136,7 @@ class OCRDataset(GenericDataset): ...@@ -136,9 +136,7 @@ class OCRDataset(GenericDataset):
sample = self.generate_synthetic_data(sample) sample = self.generate_synthetic_data(sample)
# Data augmentation # Data augmentation
sample["img"], sample["applied_da"] = self.apply_data_augmentation( sample["img"] = self.apply_data_augmentation(sample["img"])
sample["img"]
)
if "max_size" in self.params["config"] and self.params["config"]["max_size"]: if "max_size" in self.params["config"] and self.params["config"]["max_size"]:
max_ratio = max( max_ratio = max(
...@@ -523,7 +521,6 @@ class OCRCollateFunction: ...@@ -523,7 +521,6 @@ class OCRCollateFunction:
batch_data[i]["name"].split("/")[-1].split(".")[0] batch_data[i]["name"].split("/")[-1].split(".")[0]
for i in range(len(batch_data)) for i in range(len(batch_data))
] ]
applied_da = [batch_data[i]["applied_da"] for i in range(len(batch_data))]
labels = [batch_data[i]["token_label"] for i in range(len(batch_data))] labels = [batch_data[i]["token_label"] for i in range(len(batch_data))]
labels = pad_sequences_1D(labels, padding_value=self.label_padding_value) labels = pad_sequences_1D(labels, padding_value=self.label_padding_value)
...@@ -630,7 +627,6 @@ class OCRCollateFunction: ...@@ -630,7 +627,6 @@ class OCRCollateFunction:
"word_raw": word_raw, "word_raw": word_raw,
"word_labels": pad_word_token, "word_labels": pad_word_token,
"word_labels_len": word_len, "word_labels_len": word_len,
"applied_da": applied_da,
} }
return formatted_batch_data return formatted_batch_data
......
...@@ -328,9 +328,8 @@ def apply_data_augmentation(img, da_config): ...@@ -328,9 +328,8 @@ def apply_data_augmentation(img, da_config):
""" """
Apply data augmentation strategy on input image Apply data augmentation strategy on input image
""" """
applied_da = list()
if da_config["proba"] != 1 and rand() > da_config["proba"]: if da_config["proba"] != 1 and rand() > da_config["proba"]:
return img, applied_da return img
# Convert to PIL Image # Convert to PIL Image
img = img[:, :, 0] if img.shape[2] == 1 else img img = img[:, :, 0] if img.shape[2] == 1 else img
...@@ -345,12 +344,11 @@ def apply_data_augmentation(img, da_config): ...@@ -345,12 +344,11 @@ def apply_data_augmentation(img, da_config):
for augmenter in augmenters: for augmenter in augmenters:
img = augmenter(img) img = augmenter(img)
applied_da.append(type(augmenter).__name__)
# convert to numpy array # convert to numpy array
img = np.array(img) img = np.array(img)
img = np.expand_dims(img, axis=2) if len(img.shape) == 2 else img img = np.expand_dims(img, axis=2) if len(img.shape) == 2 else img
return img, applied_da return img
def apply_transform(img, transform): def apply_transform(img, transform):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment