From c2eebea38fc261cce115dc4002172bb2215deb51 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?M=C3=A9lodie=20Boillet?= <boillet@teklia.com> Date: Fri, 18 Aug 2023 09:15:49 +0000 Subject: [PATCH] Reduce memory usage during training --- dan/ocr/decoder.py | 9 +++-- dan/ocr/manager/ocr.py | 42 +++++++++++----------- dan/ocr/manager/training.py | 65 +++++++++++++++-------------------- dan/ocr/predict/prediction.py | 11 ++---- dan/ocr/transforms.py | 21 ++++------- 5 files changed, 60 insertions(+), 88 deletions(-) diff --git a/dan/ocr/decoder.py b/dan/ocr/decoder.py index 69e372cf..ccf2b4d6 100644 --- a/dan/ocr/decoder.py +++ b/dan/ocr/decoder.py @@ -337,8 +337,7 @@ class GlobalHTADecoder(Module): def forward( self, - raw_features_1d, - enhanced_features_1d, + features_1d, tokens, reduced_size, token_len, @@ -349,7 +348,7 @@ class GlobalHTADecoder(Module): num_pred=None, keep_all_weights=False, ): - device = raw_features_1d.device + device = features_1d.device # Token to Embedding pos_tokens = self.emb(tokens).permute(0, 2, 1) @@ -393,8 +392,8 @@ class GlobalHTADecoder(Module): output, weights, cache = self.att_decoder( pos_tokens, - memory_key=enhanced_features_1d, - memory_value=raw_features_1d, + memory_key=features_1d, + memory_value=features_1d, tgt_mask=target_mask, memory_mask=memory_mask, tgt_key_padding_mask=key_target_mask, diff --git a/dan/ocr/manager/ocr.py b/dan/ocr/manager/ocr.py index 5e5724d7..00262d9b 100644 --- a/dan/ocr/manager/ocr.py +++ b/dan/ocr/manager/ocr.py @@ -246,34 +246,32 @@ class OCRCollateFunction: self.label_padding_value = padding_token def __call__(self, batch_data): - labels = [batch_data[i]["token_label"] for i in range(len(batch_data))] - labels = pad_sequences_1D(labels, padding_value=self.label_padding_value).long() - - imgs = [ - torch.from_numpy(batch_data[i]["img"]).permute(2, 0, 1) - for i in range(len(batch_data)) - ] - imgs = pad_images(imgs) - formatted_batch_data = { - formatted_key: [batch_data[i][initial_key] for i in range(len(batch_data))] - for formatted_key, initial_key in zip( + "imgs": pad_images( [ - "names", - "labels_len", - "raw_labels", - "imgs_position", - "imgs_reduced_shape", - ], - ["name", "label_len", "label", "img_position", "img_reduced_shape"], - ) + torch.from_numpy(sample["img"]).permute(2, 0, 1) + for sample in batch_data + ] + ), + "labels": pad_sequences_1D( + [sample["token_label"] for sample in batch_data], + padding_value=self.label_padding_value, + ).long(), } formatted_batch_data.update( { - "imgs": imgs, - "labels": labels, + formatted_key: [sample[initial_key] for sample in batch_data] + for formatted_key, initial_key in zip( + [ + "names", + "imgs_position", + "imgs_reduced_shape", + "labels_len", + "raw_labels", + ], + ["name", "img_position", "img_reduced_shape", "label_len", "label"], + ) } ) - return formatted_batch_data diff --git a/dan/ocr/manager/training.py b/dan/ocr/manager/training.py index 2b13c515..e1c9871d 100644 --- a/dan/ocr/manager/training.py +++ b/dan/ocr/manager/training.py @@ -634,7 +634,7 @@ class GenericTrainingManager: ) if "lr" in metric_names: self.writer.add_scalar( - "lr_{}".format(model_name), + "lr/{}".format(model_name), ) # Update dropout scheduler @@ -656,7 +656,9 @@ class GenericTrainingManager: # log metrics in tensorboard file for key in display_values: self.writer.add_scalar( - "{}_{}".format(self.params["dataset"]["train"]["name"], key), + "train/{}_{}".format( + self.params["dataset"]["train"]["name"], key + ), display_values[key], num_epoch, ) @@ -677,7 +679,7 @@ class GenericTrainingManager: if self.is_master: for key in eval_values: self.writer.add_scalar( - "{}_{}".format(valid_set_name, key), + "valid/{}_{}".format(valid_set_name, key), eval_values[key], num_epoch, ) @@ -720,7 +722,7 @@ class GenericTrainingManager: tokens=self.tokens, ) with tqdm(total=len(loader.dataset)) as pbar: - pbar.set_description("Validation E{}".format(self.latest_epoch)) + pbar.set_description("VALID {} - {}".format(self.latest_epoch, set_name)) with torch.no_grad(): # iterate over batch data for ind_batch, batch_data in enumerate(loader): @@ -945,9 +947,8 @@ class Manager(GenericTrainingManager): loss_func = CrossEntropyLoss(ignore_index=self.dataset.tokens["pad"]) sum_loss = 0 - x = batch_data["imgs"].to(self.device) - y = batch_data["labels"].to(self.device) - reduced_size = [s[:2] for s in batch_data["imgs_reduced_shape"]] + b = batch_data["imgs"].shape[0] + batch_data["labels"] = batch_data["labels"].to(self.device) y_len = batch_data["labels_len"] if "label_noise_scheduler" in self.params["training"]: @@ -963,38 +964,33 @@ class Manager(GenericTrainingManager): ) / self.params["training"]["label_noise_scheduler"]["total_num_steps"] ) - simulated_y_pred, y_len = self.add_label_noise(y, y_len, error_rate) + simulated_y_pred, y_len = self.add_label_noise( + batch_data["labels"], y_len, error_rate + ) else: - simulated_y_pred = y + simulated_y_pred = batch_data["labels"] with autocast(enabled=self.device_params["use_amp"]): hidden_predict = None cache = None - raw_features = self.models["encoder"](x) - features_size = raw_features.size() - b, c, h, w = features_size + features = self.models["encoder"](batch_data["imgs"].to(self.device)) + features_size = features.size() if self.device_params["use_ddp"]: - pos_features = self.models[ + features = self.models[ "decoder" - ].module.features_updater.get_pos_features(raw_features) + ].module.features_updater.get_pos_features(features) else: - pos_features = self.models["decoder"].features_updater.get_pos_features( - raw_features + features = self.models["decoder"].features_updater.get_pos_features( + features ) - features = torch.flatten(pos_features, start_dim=2, end_dim=3).permute( - 2, 0, 1 - ) - enhanced_features = pos_features - enhanced_features = torch.flatten( - enhanced_features, start_dim=2, end_dim=3 - ).permute(2, 0, 1) + features = torch.flatten(features, start_dim=2, end_dim=3).permute(2, 0, 1) + output, pred, hidden_predict, cache, weights = self.models["decoder"]( features, - enhanced_features, simulated_y_pred[:, :-1], - reduced_size, + [s[:2] for s in batch_data["imgs_reduced_shape"]], [max(y_len) for _ in range(b)], features_size, start=0, @@ -1003,7 +999,7 @@ class Manager(GenericTrainingManager): keep_all_weights=True, ) - loss_ce = loss_func(pred, y[:, 1:]) + loss_ce = loss_func(pred, batch_data["labels"][:, 1:]) sum_loss += loss_ce with autocast(enabled=False): self.backward_loss(sum_loss) @@ -1028,7 +1024,6 @@ class Manager(GenericTrainingManager): def evaluate_batch(self, batch_data, metric_names): x = batch_data["imgs"].to(self.device) - reduced_size = [s[:2] for s in batch_data["imgs_reduced_shape"]] max_chars = self.params["dataset"]["max_char_prediction"] @@ -1075,28 +1070,22 @@ class Manager(GenericTrainingManager): else: features = self.models["encoder"](x) features_size = features.size() + if self.device_params["use_ddp"]: - pos_features = self.models[ + features = self.models[ "decoder" ].module.features_updater.get_pos_features(features) else: - pos_features = self.models["decoder"].features_updater.get_pos_features( + features = self.models["decoder"].features_updater.get_pos_features( features ) - features = torch.flatten(pos_features, start_dim=2, end_dim=3).permute( - 2, 0, 1 - ) - enhanced_features = pos_features - enhanced_features = torch.flatten( - enhanced_features, start_dim=2, end_dim=3 - ).permute(2, 0, 1) + features = torch.flatten(features, start_dim=2, end_dim=3).permute(2, 0, 1) for i in range(0, max_chars): output, pred, hidden_predict, cache, weights = self.models["decoder"]( features, - enhanced_features, predicted_tokens, - reduced_size, + [s[:2] for s in batch_data["imgs_reduced_shape"]], predicted_tokens_len, features_size, start=0, diff --git a/dan/ocr/predict/prediction.py b/dan/ocr/predict/prediction.py index 4da09183..faa33839 100644 --- a/dan/ocr/predict/prediction.py +++ b/dan/ocr/predict/prediction.py @@ -159,14 +159,8 @@ class DAN: features = self.encoder(input_tensor.float()) features_size = features.size() - pos_features = self.decoder.features_updater.get_pos_features(features) - features = torch.flatten(pos_features, start_dim=2, end_dim=3).permute( - 2, 0, 1 - ) - enhanced_features = pos_features - enhanced_features = torch.flatten( - enhanced_features, start_dim=2, end_dim=3 - ).permute(2, 0, 1) + features = self.decoder.features_updater.get_pos_features(features) + features = torch.flatten(features, start_dim=2, end_dim=3).permute(2, 0, 1) for i in range(0, self.max_chars): ( @@ -177,7 +171,6 @@ class DAN: weights, ) = self.decoder( features, - enhanced_features, predicted_tokens, input_sizes, predicted_tokens_len, diff --git a/dan/ocr/transforms.py b/dan/ocr/transforms.py index f17aa900..588c887f 100644 --- a/dan/ocr/transforms.py +++ b/dan/ocr/transforms.py @@ -19,13 +19,12 @@ from albumentations.augmentations import ( ToGray, ) from albumentations.core.transforms_interface import ImageOnlyTransform -from cv2 import dilate, erode +from cv2 import dilate, erode, resize from numpy import random -from PIL import Image from torch import Tensor from torch.distributions.uniform import Uniform from torchvision.transforms import Compose, ToPILImage -from torchvision.transforms.functional import resize +from torchvision.transforms.functional import resize as resize_tensor class Preprocessing(str, Enum): @@ -47,7 +46,7 @@ class FixedHeightResize: def __call__(self, img: Tensor) -> Tensor: size = (self.height, self._calc_new_width(img)) - return resize(img, size, antialias=False) + return resize_tensor(img, size, antialias=False) def _calc_new_width(self, img: Tensor) -> int: aspect_ratio = img.shape[2] / img.shape[1] @@ -64,7 +63,7 @@ class FixedWidthResize: def __call__(self, img: Tensor) -> Tensor: size = (self._calc_new_height(img), self.width) - return resize(img, size, antialias=False) + return resize_tensor(img, size, antialias=False) def _calc_new_height(self, img: Tensor) -> int: aspect_ratio = img.shape[1] / img.shape[2] @@ -89,7 +88,7 @@ class MaxResize: ratio = min(height_ratio, width_ratio) new_width = int(width * ratio) new_height = int(height * ratio) - return resize(img, (new_height, new_width), antialias=False) + return resize_tensor(img, (new_height, new_width), antialias=False) class Dilation: @@ -142,12 +141,11 @@ class ErosionDilation(ImageOnlyTransform): kernel_h = randint(self.min_kernel, self.max_kernel) kernel_w = randint(self.min_kernel, self.max_kernel) kernel = np.ones((kernel_h, kernel_w), np.uint8) - augmented_image = ( + return ( Erosion(kernel, iterations=self.iterations)(img) if random.random() < 0.5 else Dilation(kernel=kernel, iterations=self.iterations)(img) ) - return augmented_image class DPIAdjusting(ImageOnlyTransform): @@ -170,12 +168,7 @@ class DPIAdjusting(ImageOnlyTransform): def apply(self, img: np.ndarray, **params): factor = float(Uniform(self.min_factor, self.max_factor).sample()) - img = Image.fromarray(img) - augmented_image = img.resize( - (int(np.ceil(img.width * factor)), int(np.ceil(img.height * factor))), - Image.BILINEAR, - ) - return np.array(augmented_image) + return resize(img, None, fx=factor, fy=factor) def get_preprocessing_transforms( -- GitLab