From 0ce6787ac105276abefa295b6f2eefc9b74ba43d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?M=C3=A9lodie=20Boillet?= <boillet@teklia.com> Date: Wed, 24 May 2023 14:28:31 +0200 Subject: [PATCH] Remove list of applied data augmentations --- dan/manager/dataset.py | 2 +- dan/manager/ocr.py | 6 +----- dan/transforms.py | 6 ++---- 3 files changed, 4 insertions(+), 10 deletions(-) diff --git a/dan/manager/dataset.py b/dan/manager/dataset.py index 2b28e6dc..942ba5c2 100644 --- a/dan/manager/dataset.py +++ b/dan/manager/dataset.py @@ -336,7 +336,7 @@ class GenericDataset(Dataset): for aug, set_name in zip(augs, ["train", "val", "test"]): if aug and self.set_name == set_name: return apply_data_augmentation(img, aug) - return img, list() + return img def get_sample_img(self, i): """ diff --git a/dan/manager/ocr.py b/dan/manager/ocr.py index bd464073..0ca7ab3d 100644 --- a/dan/manager/ocr.py +++ b/dan/manager/ocr.py @@ -136,9 +136,7 @@ class OCRDataset(GenericDataset): sample = self.generate_synthetic_data(sample) # Data augmentation - sample["img"], sample["applied_da"] = self.apply_data_augmentation( - sample["img"] - ) + sample["img"] = self.apply_data_augmentation(sample["img"]) if "max_size" in self.params["config"] and self.params["config"]["max_size"]: max_ratio = max( @@ -523,7 +521,6 @@ class OCRCollateFunction: batch_data[i]["name"].split("/")[-1].split(".")[0] for i in range(len(batch_data)) ] - applied_da = [batch_data[i]["applied_da"] for i in range(len(batch_data))] labels = [batch_data[i]["token_label"] for i in range(len(batch_data))] labels = pad_sequences_1D(labels, padding_value=self.label_padding_value) @@ -630,7 +627,6 @@ class OCRCollateFunction: "word_raw": word_raw, "word_labels": pad_word_token, "word_labels_len": word_len, - "applied_da": applied_da, } return formatted_batch_data diff --git a/dan/transforms.py b/dan/transforms.py index 89cb5e29..957a6453 100644 --- a/dan/transforms.py +++ b/dan/transforms.py @@ -328,9 +328,8 @@ def apply_data_augmentation(img, da_config): """ Apply data augmentation strategy on input image """ - applied_da = list() if da_config["proba"] != 1 and rand() > da_config["proba"]: - return img, applied_da + return img # Convert to PIL Image img = img[:, :, 0] if img.shape[2] == 1 else img @@ -345,12 +344,11 @@ def apply_data_augmentation(img, da_config): for augmenter in augmenters: img = augmenter(img) - applied_da.append(type(augmenter).__name__) # convert to numpy array img = np.array(img) img = np.expand_dims(img, axis=2) if len(img.shape) == 2 else img - return img, applied_da + return img def apply_transform(img, transform): -- GitLab