From 0ce6787ac105276abefa295b6f2eefc9b74ba43d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?M=C3=A9lodie=20Boillet?= <boillet@teklia.com>
Date: Wed, 24 May 2023 14:28:31 +0200
Subject: [PATCH] Remove list of applied data augmentations

---
 dan/manager/dataset.py | 2 +-
 dan/manager/ocr.py     | 6 +-----
 dan/transforms.py      | 6 ++----
 3 files changed, 4 insertions(+), 10 deletions(-)

diff --git a/dan/manager/dataset.py b/dan/manager/dataset.py
index 2b28e6dc..942ba5c2 100644
--- a/dan/manager/dataset.py
+++ b/dan/manager/dataset.py
@@ -336,7 +336,7 @@ class GenericDataset(Dataset):
         for aug, set_name in zip(augs, ["train", "val", "test"]):
             if aug and self.set_name == set_name:
                 return apply_data_augmentation(img, aug)
-        return img, list()
+        return img
 
     def get_sample_img(self, i):
         """
diff --git a/dan/manager/ocr.py b/dan/manager/ocr.py
index bd464073..0ca7ab3d 100644
--- a/dan/manager/ocr.py
+++ b/dan/manager/ocr.py
@@ -136,9 +136,7 @@ class OCRDataset(GenericDataset):
             sample = self.generate_synthetic_data(sample)
 
         # Data augmentation
-        sample["img"], sample["applied_da"] = self.apply_data_augmentation(
-            sample["img"]
-        )
+        sample["img"] = self.apply_data_augmentation(sample["img"])
 
         if "max_size" in self.params["config"] and self.params["config"]["max_size"]:
             max_ratio = max(
@@ -523,7 +521,6 @@ class OCRCollateFunction:
             batch_data[i]["name"].split("/")[-1].split(".")[0]
             for i in range(len(batch_data))
         ]
-        applied_da = [batch_data[i]["applied_da"] for i in range(len(batch_data))]
 
         labels = [batch_data[i]["token_label"] for i in range(len(batch_data))]
         labels = pad_sequences_1D(labels, padding_value=self.label_padding_value)
@@ -630,7 +627,6 @@ class OCRCollateFunction:
             "word_raw": word_raw,
             "word_labels": pad_word_token,
             "word_labels_len": word_len,
-            "applied_da": applied_da,
         }
 
         return formatted_batch_data
diff --git a/dan/transforms.py b/dan/transforms.py
index 89cb5e29..957a6453 100644
--- a/dan/transforms.py
+++ b/dan/transforms.py
@@ -328,9 +328,8 @@ def apply_data_augmentation(img, da_config):
     """
     Apply data augmentation strategy on input image
     """
-    applied_da = list()
     if da_config["proba"] != 1 and rand() > da_config["proba"]:
-        return img, applied_da
+        return img
 
     # Convert to PIL Image
     img = img[:, :, 0] if img.shape[2] == 1 else img
@@ -345,12 +344,11 @@ def apply_data_augmentation(img, da_config):
 
     for augmenter in augmenters:
         img = augmenter(img)
-        applied_da.append(type(augmenter).__name__)
 
     # convert to numpy array
     img = np.array(img)
     img = np.expand_dims(img, axis=2) if len(img.shape) == 2 else img
-    return img, applied_da
+    return img
 
 
 def apply_transform(img, transform):
-- 
GitLab