From c2eebea38fc261cce115dc4002172bb2215deb51 Mon Sep 17 00:00:00 2001
From: Mélodie Boillet <boillet@teklia.com>
Date: Fri, 18 Aug 2023 09:15:49 +0000
Subject: [PATCH] Reduce memory usage during training

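The decoder was fed the same flattened positional features twice: once
as `raw_features_1d` (the attention values) and once as
`enhanced_features_1d` (the attention keys), both obtained by
flattening the same `pos_features` tensor. Merge the two arguments
into a single `features_1d` parameter, and rebind a single `features`
variable through the encode / positional-encoding / flatten pipeline
so that each intermediate tensor can be freed as soon as the next step
has consumed it.

Related changes:
- build the collated batch dict directly in
  `OCRCollateFunction.__call__` instead of keeping intermediate
  `imgs`/`labels` lists alive;
- move images to the device only when the encoder consumes them;
- replace the PIL round-trip in `DPIAdjusting` with a single
  `cv2.resize` call;
- namespace the TensorBoard tags (`lr/`, `train/`, `valid/`) and show
  the set name in the validation progress bar.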
---
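Note: passing one tensor as both the attention key and value is
supported by PyTorch directly, so merging
`raw_features_1d`/`enhanced_features_1d` does not change the decoder
output, since the two inputs were identical. A minimal sketch of the
idea with made-up sizes (the `attn` module and all shapes below are
illustrative, not the real `att_decoder`):

    import torch
    from torch import nn

    attn = nn.MultiheadAttention(embed_dim=256, num_heads=8)
    features_1d = torch.randn(512, 4, 256)  # (h*w, batch, dim)
    tokens = torch.randn(20, 4, 256)        # (len, batch, dim)

    # The same tensor serves as key and value; no duplicate buffer
    # has to stay alive for the whole decoding loop.
    out, weights = attn(tokens, features_1d, features_1d)
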
 dan/ocr/decoder.py            |  9 +++--
 dan/ocr/manager/ocr.py        | 42 +++++++++++-----------
 dan/ocr/manager/training.py   | 65 +++++++++++++++--------------------
 dan/ocr/predict/prediction.py | 11 ++----
 dan/ocr/transforms.py         | 21 ++++-------
 5 files changed, 60 insertions(+), 88 deletions(-)

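The `train_batch`/`evaluate_batch` hunks below also rebind a single
`features` variable instead of naming every intermediate
(`raw_features`, `pos_features`, `enhanced_features`). Each rebinding
drops the last Python reference to the previous intermediate, so,
outside of what autograd retains for the backward pass, its memory can
be reclaimed before the decoding loop runs. Sketch of the pattern
(`encode` and `add_pos` are hypothetical stand-ins for the encoder and
`get_pos_features`, not the real models):

    import torch

    def encode(x):   # stand-in for self.models["encoder"]
        return x * 2

    def add_pos(x):  # stand-in for features_updater.get_pos_features
        return x + 1

    imgs = torch.randn(8, 3, 128, 1024)

    # After each assignment, the previous value of `features` becomes
    # unreferenced and its memory can be reused by the allocator.
    features = encode(imgs)
    features = add_pos(features)
    features = torch.flatten(features, start_dim=2, end_dim=3)
    features = features.permute(2, 0, 1)
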
diff --git a/dan/ocr/decoder.py b/dan/ocr/decoder.py
index 69e372cf..ccf2b4d6 100644
--- a/dan/ocr/decoder.py
+++ b/dan/ocr/decoder.py
@@ -337,8 +337,7 @@ class GlobalHTADecoder(Module):
 
     def forward(
         self,
-        raw_features_1d,
-        enhanced_features_1d,
+        features_1d,
         tokens,
         reduced_size,
         token_len,
@@ -349,7 +348,7 @@ class GlobalHTADecoder(Module):
         num_pred=None,
         keep_all_weights=False,
     ):
-        device = raw_features_1d.device
+        device = features_1d.device
 
         # Token to Embedding
         pos_tokens = self.emb(tokens).permute(0, 2, 1)
@@ -393,8 +392,8 @@ class GlobalHTADecoder(Module):
 
         output, weights, cache = self.att_decoder(
             pos_tokens,
-            memory_key=enhanced_features_1d,
-            memory_value=raw_features_1d,
+            memory_key=features_1d,
+            memory_value=features_1d,
             tgt_mask=target_mask,
             memory_mask=memory_mask,
             tgt_key_padding_mask=key_target_mask,
diff --git a/dan/ocr/manager/ocr.py b/dan/ocr/manager/ocr.py
index 5e5724d7..00262d9b 100644
--- a/dan/ocr/manager/ocr.py
+++ b/dan/ocr/manager/ocr.py
@@ -246,34 +246,32 @@ class OCRCollateFunction:
         self.label_padding_value = padding_token
 
     def __call__(self, batch_data):
-        labels = [batch_data[i]["token_label"] for i in range(len(batch_data))]
-        labels = pad_sequences_1D(labels, padding_value=self.label_padding_value).long()
-
-        imgs = [
-            torch.from_numpy(batch_data[i]["img"]).permute(2, 0, 1)
-            for i in range(len(batch_data))
-        ]
-        imgs = pad_images(imgs)
-
         formatted_batch_data = {
-            formatted_key: [batch_data[i][initial_key] for i in range(len(batch_data))]
-            for formatted_key, initial_key in zip(
+            "imgs": pad_images(
                 [
-                    "names",
-                    "labels_len",
-                    "raw_labels",
-                    "imgs_position",
-                    "imgs_reduced_shape",
-                ],
-                ["name", "label_len", "label", "img_position", "img_reduced_shape"],
-            )
+                    torch.from_numpy(sample["img"]).permute(2, 0, 1)
+                    for sample in batch_data
+                ]
+            ),
+            "labels": pad_sequences_1D(
+                [sample["token_label"] for sample in batch_data],
+                padding_value=self.label_padding_value,
+            ).long(),
         }
 
         formatted_batch_data.update(
             {
-                "imgs": imgs,
-                "labels": labels,
+                formatted_key: [sample[initial_key] for sample in batch_data]
+                for formatted_key, initial_key in zip(
+                    [
+                        "names",
+                        "imgs_position",
+                        "imgs_reduced_shape",
+                        "labels_len",
+                        "raw_labels",
+                    ],
+                    ["name", "img_position", "img_reduced_shape", "label_len", "label"],
+                )
             }
         )
-
         return formatted_batch_data
diff --git a/dan/ocr/manager/training.py b/dan/ocr/manager/training.py
index 2b13c515..e1c9871d 100644
--- a/dan/ocr/manager/training.py
+++ b/dan/ocr/manager/training.py
@@ -634,7 +634,7 @@ class GenericTrainingManager:
                                 )
                                 if "lr" in metric_names:
                                     self.writer.add_scalar(
-                                        "lr_{}".format(model_name),
+                                        "lr/{}".format(model_name),
                                         self.lr_schedulers[model_name].lr,
                                         num_epoch,
                                     )
 
                     # Update dropout scheduler
@@ -656,7 +656,9 @@ class GenericTrainingManager:
                 # log metrics in tensorboard file
                 for key in display_values:
                     self.writer.add_scalar(
-                        "{}_{}".format(self.params["dataset"]["train"]["name"], key),
+                        "train/{}_{}".format(
+                            self.params["dataset"]["train"]["name"], key
+                        ),
                         display_values[key],
                         num_epoch,
                     )
@@ -677,7 +679,7 @@ class GenericTrainingManager:
                     if self.is_master:
                         for key in eval_values:
                             self.writer.add_scalar(
-                                "{}_{}".format(valid_set_name, key),
+                                "valid/{}_{}".format(valid_set_name, key),
                                 eval_values[key],
                                 num_epoch,
                             )
@@ -720,7 +722,7 @@ class GenericTrainingManager:
             tokens=self.tokens,
         )
         with tqdm(total=len(loader.dataset)) as pbar:
-            pbar.set_description("Validation E{}".format(self.latest_epoch))
+            pbar.set_description("VALID {} - {}".format(self.latest_epoch, set_name))
             with torch.no_grad():
                 # iterate over batch data
                 for ind_batch, batch_data in enumerate(loader):
@@ -945,9 +947,8 @@ class Manager(GenericTrainingManager):
         loss_func = CrossEntropyLoss(ignore_index=self.dataset.tokens["pad"])
 
         sum_loss = 0
-        x = batch_data["imgs"].to(self.device)
-        y = batch_data["labels"].to(self.device)
-        reduced_size = [s[:2] for s in batch_data["imgs_reduced_shape"]]
+        b = batch_data["imgs"].shape[0]
+        batch_data["labels"] = batch_data["labels"].to(self.device)
         y_len = batch_data["labels_len"]
 
         if "label_noise_scheduler" in self.params["training"]:
@@ -963,38 +964,33 @@ class Manager(GenericTrainingManager):
                 )
                 / self.params["training"]["label_noise_scheduler"]["total_num_steps"]
             )
-            simulated_y_pred, y_len = self.add_label_noise(y, y_len, error_rate)
+            simulated_y_pred, y_len = self.add_label_noise(
+                batch_data["labels"], y_len, error_rate
+            )
         else:
-            simulated_y_pred = y
+            simulated_y_pred = batch_data["labels"]
 
         with autocast(enabled=self.device_params["use_amp"]):
             hidden_predict = None
             cache = None
 
-            raw_features = self.models["encoder"](x)
-            features_size = raw_features.size()
-            b, c, h, w = features_size
+            features = self.models["encoder"](batch_data["imgs"].to(self.device))
+            features_size = features.size()
 
             if self.device_params["use_ddp"]:
-                pos_features = self.models[
+                features = self.models[
                     "decoder"
-                ].module.features_updater.get_pos_features(raw_features)
+                ].module.features_updater.get_pos_features(features)
             else:
-                pos_features = self.models["decoder"].features_updater.get_pos_features(
-                    raw_features
+                features = self.models["decoder"].features_updater.get_pos_features(
+                    features
                 )
-            features = torch.flatten(pos_features, start_dim=2, end_dim=3).permute(
-                2, 0, 1
-            )
-            enhanced_features = pos_features
-            enhanced_features = torch.flatten(
-                enhanced_features, start_dim=2, end_dim=3
-            ).permute(2, 0, 1)
+            features = torch.flatten(features, start_dim=2, end_dim=3).permute(2, 0, 1)
+
             output, pred, hidden_predict, cache, weights = self.models["decoder"](
                 features,
-                enhanced_features,
                 simulated_y_pred[:, :-1],
-                reduced_size,
+                [s[:2] for s in batch_data["imgs_reduced_shape"]],
                 [max(y_len) for _ in range(b)],
                 features_size,
                 start=0,
@@ -1003,7 +999,7 @@ class Manager(GenericTrainingManager):
                 keep_all_weights=True,
             )
 
-            loss_ce = loss_func(pred, y[:, 1:])
+            loss_ce = loss_func(pred, batch_data["labels"][:, 1:])
             sum_loss += loss_ce
             with autocast(enabled=False):
                 self.backward_loss(sum_loss)
@@ -1028,7 +1024,6 @@ class Manager(GenericTrainingManager):
 
     def evaluate_batch(self, batch_data, metric_names):
         x = batch_data["imgs"].to(self.device)
-        reduced_size = [s[:2] for s in batch_data["imgs_reduced_shape"]]
 
         max_chars = self.params["dataset"]["max_char_prediction"]
 
@@ -1075,28 +1070,22 @@ class Manager(GenericTrainingManager):
             else:
                 features = self.models["encoder"](x)
             features_size = features.size()
+
             if self.device_params["use_ddp"]:
-                pos_features = self.models[
+                features = self.models[
                     "decoder"
                 ].module.features_updater.get_pos_features(features)
             else:
-                pos_features = self.models["decoder"].features_updater.get_pos_features(
+                features = self.models["decoder"].features_updater.get_pos_features(
                     features
                 )
-            features = torch.flatten(pos_features, start_dim=2, end_dim=3).permute(
-                2, 0, 1
-            )
-            enhanced_features = pos_features
-            enhanced_features = torch.flatten(
-                enhanced_features, start_dim=2, end_dim=3
-            ).permute(2, 0, 1)
+            features = torch.flatten(features, start_dim=2, end_dim=3).permute(2, 0, 1)
 
             for i in range(0, max_chars):
                 output, pred, hidden_predict, cache, weights = self.models["decoder"](
                     features,
-                    enhanced_features,
                     predicted_tokens,
-                    reduced_size,
+                    [s[:2] for s in batch_data["imgs_reduced_shape"]],
                     predicted_tokens_len,
                     features_size,
                     start=0,
diff --git a/dan/ocr/predict/prediction.py b/dan/ocr/predict/prediction.py
index 4da09183..faa33839 100644
--- a/dan/ocr/predict/prediction.py
+++ b/dan/ocr/predict/prediction.py
@@ -159,14 +159,8 @@ class DAN:
 
             features = self.encoder(input_tensor.float())
             features_size = features.size()
-            pos_features = self.decoder.features_updater.get_pos_features(features)
-            features = torch.flatten(pos_features, start_dim=2, end_dim=3).permute(
-                2, 0, 1
-            )
-            enhanced_features = pos_features
-            enhanced_features = torch.flatten(
-                enhanced_features, start_dim=2, end_dim=3
-            ).permute(2, 0, 1)
+            features = self.decoder.features_updater.get_pos_features(features)
+            features = torch.flatten(features, start_dim=2, end_dim=3).permute(2, 0, 1)
 
             for i in range(0, self.max_chars):
                 (
@@ -177,7 +171,6 @@
                     weights,
                 ) = self.decoder(
                     features,
-                    enhanced_features,
                     predicted_tokens,
                     input_sizes,
                     predicted_tokens_len,
diff --git a/dan/ocr/transforms.py b/dan/ocr/transforms.py
index f17aa900..588c887f 100644
--- a/dan/ocr/transforms.py
+++ b/dan/ocr/transforms.py
@@ -19,13 +19,12 @@ from albumentations.augmentations import (
     ToGray,
 )
 from albumentations.core.transforms_interface import ImageOnlyTransform
-from cv2 import dilate, erode
+from cv2 import dilate, erode, resize
 from numpy import random
-from PIL import Image
 from torch import Tensor
 from torch.distributions.uniform import Uniform
 from torchvision.transforms import Compose, ToPILImage
-from torchvision.transforms.functional import resize
+from torchvision.transforms.functional import resize as resize_tensor
 
 
 class Preprocessing(str, Enum):
@@ -47,7 +46,7 @@ class FixedHeightResize:
 
     def __call__(self, img: Tensor) -> Tensor:
         size = (self.height, self._calc_new_width(img))
-        return resize(img, size, antialias=False)
+        return resize_tensor(img, size, antialias=False)
 
     def _calc_new_width(self, img: Tensor) -> int:
         aspect_ratio = img.shape[2] / img.shape[1]
@@ -64,7 +63,7 @@ class FixedWidthResize:
 
     def __call__(self, img: Tensor) -> Tensor:
         size = (self._calc_new_height(img), self.width)
-        return resize(img, size, antialias=False)
+        return resize_tensor(img, size, antialias=False)
 
     def _calc_new_height(self, img: Tensor) -> int:
         aspect_ratio = img.shape[1] / img.shape[2]
@@ -89,7 +88,7 @@ class MaxResize:
         ratio = min(height_ratio, width_ratio)
         new_width = int(width * ratio)
         new_height = int(height * ratio)
-        return resize(img, (new_height, new_width), antialias=False)
+        return resize_tensor(img, (new_height, new_width), antialias=False)
 
 
 class Dilation:
@@ -142,12 +141,11 @@ class ErosionDilation(ImageOnlyTransform):
         kernel_h = randint(self.min_kernel, self.max_kernel)
         kernel_w = randint(self.min_kernel, self.max_kernel)
         kernel = np.ones((kernel_h, kernel_w), np.uint8)
-        augmented_image = (
+        return (
             Erosion(kernel, iterations=self.iterations)(img)
             if random.random() < 0.5
             else Dilation(kernel=kernel, iterations=self.iterations)(img)
         )
-        return augmented_image
 
 
 class DPIAdjusting(ImageOnlyTransform):
@@ -170,12 +168,7 @@ class DPIAdjusting(ImageOnlyTransform):
 
     def apply(self, img: np.ndarray, **params):
         factor = float(Uniform(self.min_factor, self.max_factor).sample())
-        img = Image.fromarray(img)
-        augmented_image = img.resize(
-            (int(np.ceil(img.width * factor)), int(np.ceil(img.height * factor))),
-            Image.BILINEAR,
-        )
-        return np.array(augmented_image)
+        return resize(img, None, fx=factor, fy=factor)
 
 
 def get_preprocessing_transforms(
-- 
GitLab