diff --git a/dan/manager/dataset.py b/dan/manager/dataset.py index 2b28e6dc1428b57a0c247ef92b8ef6acb8b32377..942ba5c2a43292d392e4006e7831a250be97fc55 100644 --- a/dan/manager/dataset.py +++ b/dan/manager/dataset.py @@ -336,7 +336,7 @@ class GenericDataset(Dataset): for aug, set_name in zip(augs, ["train", "val", "test"]): if aug and self.set_name == set_name: return apply_data_augmentation(img, aug) - return img, list() + return img def get_sample_img(self, i): """ diff --git a/dan/manager/ocr.py b/dan/manager/ocr.py index bd4640731aafa905b05b596233e917e1971a7209..0ca7ab3d5beb4b9b0b06e9cbc744f8687c643731 100644 --- a/dan/manager/ocr.py +++ b/dan/manager/ocr.py @@ -136,9 +136,7 @@ class OCRDataset(GenericDataset): sample = self.generate_synthetic_data(sample) # Data augmentation - sample["img"], sample["applied_da"] = self.apply_data_augmentation( - sample["img"] - ) + sample["img"] = self.apply_data_augmentation(sample["img"]) if "max_size" in self.params["config"] and self.params["config"]["max_size"]: max_ratio = max( @@ -523,7 +521,6 @@ class OCRCollateFunction: batch_data[i]["name"].split("/")[-1].split(".")[0] for i in range(len(batch_data)) ] - applied_da = [batch_data[i]["applied_da"] for i in range(len(batch_data))] labels = [batch_data[i]["token_label"] for i in range(len(batch_data))] labels = pad_sequences_1D(labels, padding_value=self.label_padding_value) @@ -630,7 +627,6 @@ class OCRCollateFunction: "word_raw": word_raw, "word_labels": pad_word_token, "word_labels_len": word_len, - "applied_da": applied_da, } return formatted_batch_data diff --git a/dan/transforms.py b/dan/transforms.py index 89cb5e29eb3603bc62cc1c7e293aebbd1282c9b4..957a6453635df4043082539c0115d943d16f9ae1 100644 --- a/dan/transforms.py +++ b/dan/transforms.py @@ -328,9 +328,8 @@ def apply_data_augmentation(img, da_config): """ Apply data augmentation strategy on input image """ - applied_da = list() if da_config["proba"] != 1 and rand() > da_config["proba"]: - return img, applied_da + return img # Convert to PIL Image img = img[:, :, 0] if img.shape[2] == 1 else img @@ -345,12 +344,11 @@ def apply_data_augmentation(img, da_config): for augmenter in augmenters: img = augmenter(img) - applied_da.append(type(augmenter).__name__) # convert to numpy array img = np.array(img) img = np.expand_dims(img, axis=2) if len(img.shape) == 2 else img - return img, applied_da + return img def apply_transform(img, transform):