Skip to content
Snippets Groups Projects
Commit c2eebea3 authored by Mélodie Boillet's avatar Mélodie Boillet Committed by Yoann Schneider
Browse files

Reduce memory usage during training

parent a47466ef
No related branches found
No related tags found
1 merge request!241Reduce memory usage during training
...@@ -337,8 +337,7 @@ class GlobalHTADecoder(Module): ...@@ -337,8 +337,7 @@ class GlobalHTADecoder(Module):
def forward( def forward(
self, self,
raw_features_1d, features_1d,
enhanced_features_1d,
tokens, tokens,
reduced_size, reduced_size,
token_len, token_len,
...@@ -349,7 +348,7 @@ class GlobalHTADecoder(Module): ...@@ -349,7 +348,7 @@ class GlobalHTADecoder(Module):
num_pred=None, num_pred=None,
keep_all_weights=False, keep_all_weights=False,
): ):
device = raw_features_1d.device device = features_1d.device
# Token to Embedding # Token to Embedding
pos_tokens = self.emb(tokens).permute(0, 2, 1) pos_tokens = self.emb(tokens).permute(0, 2, 1)
...@@ -393,8 +392,8 @@ class GlobalHTADecoder(Module): ...@@ -393,8 +392,8 @@ class GlobalHTADecoder(Module):
output, weights, cache = self.att_decoder( output, weights, cache = self.att_decoder(
pos_tokens, pos_tokens,
memory_key=enhanced_features_1d, memory_key=features_1d,
memory_value=raw_features_1d, memory_value=features_1d,
tgt_mask=target_mask, tgt_mask=target_mask,
memory_mask=memory_mask, memory_mask=memory_mask,
tgt_key_padding_mask=key_target_mask, tgt_key_padding_mask=key_target_mask,
......
...@@ -246,34 +246,32 @@ class OCRCollateFunction: ...@@ -246,34 +246,32 @@ class OCRCollateFunction:
self.label_padding_value = padding_token self.label_padding_value = padding_token
def __call__(self, batch_data): def __call__(self, batch_data):
labels = [batch_data[i]["token_label"] for i in range(len(batch_data))]
labels = pad_sequences_1D(labels, padding_value=self.label_padding_value).long()
imgs = [
torch.from_numpy(batch_data[i]["img"]).permute(2, 0, 1)
for i in range(len(batch_data))
]
imgs = pad_images(imgs)
formatted_batch_data = { formatted_batch_data = {
formatted_key: [batch_data[i][initial_key] for i in range(len(batch_data))] "imgs": pad_images(
for formatted_key, initial_key in zip(
[ [
"names", torch.from_numpy(sample["img"]).permute(2, 0, 1)
"labels_len", for sample in batch_data
"raw_labels", ]
"imgs_position", ),
"imgs_reduced_shape", "labels": pad_sequences_1D(
], [sample["token_label"] for sample in batch_data],
["name", "label_len", "label", "img_position", "img_reduced_shape"], padding_value=self.label_padding_value,
) ).long(),
} }
formatted_batch_data.update( formatted_batch_data.update(
{ {
"imgs": imgs, formatted_key: [sample[initial_key] for sample in batch_data]
"labels": labels, for formatted_key, initial_key in zip(
[
"names",
"imgs_position",
"imgs_reduced_shape",
"labels_len",
"raw_labels",
],
["name", "img_position", "img_reduced_shape", "label_len", "label"],
)
} }
) )
return formatted_batch_data return formatted_batch_data
...@@ -634,7 +634,7 @@ class GenericTrainingManager: ...@@ -634,7 +634,7 @@ class GenericTrainingManager:
) )
if "lr" in metric_names: if "lr" in metric_names:
self.writer.add_scalar( self.writer.add_scalar(
"lr_{}".format(model_name), "lr/{}".format(model_name),
) )
# Update dropout scheduler # Update dropout scheduler
...@@ -656,7 +656,9 @@ class GenericTrainingManager: ...@@ -656,7 +656,9 @@ class GenericTrainingManager:
# log metrics in tensorboard file # log metrics in tensorboard file
for key in display_values: for key in display_values:
self.writer.add_scalar( self.writer.add_scalar(
"{}_{}".format(self.params["dataset"]["train"]["name"], key), "train/{}_{}".format(
self.params["dataset"]["train"]["name"], key
),
display_values[key], display_values[key],
num_epoch, num_epoch,
) )
...@@ -677,7 +679,7 @@ class GenericTrainingManager: ...@@ -677,7 +679,7 @@ class GenericTrainingManager:
if self.is_master: if self.is_master:
for key in eval_values: for key in eval_values:
self.writer.add_scalar( self.writer.add_scalar(
"{}_{}".format(valid_set_name, key), "valid/{}_{}".format(valid_set_name, key),
eval_values[key], eval_values[key],
num_epoch, num_epoch,
) )
...@@ -720,7 +722,7 @@ class GenericTrainingManager: ...@@ -720,7 +722,7 @@ class GenericTrainingManager:
tokens=self.tokens, tokens=self.tokens,
) )
with tqdm(total=len(loader.dataset)) as pbar: with tqdm(total=len(loader.dataset)) as pbar:
pbar.set_description("Validation E{}".format(self.latest_epoch)) pbar.set_description("VALID {} - {}".format(self.latest_epoch, set_name))
with torch.no_grad(): with torch.no_grad():
# iterate over batch data # iterate over batch data
for ind_batch, batch_data in enumerate(loader): for ind_batch, batch_data in enumerate(loader):
...@@ -945,9 +947,8 @@ class Manager(GenericTrainingManager): ...@@ -945,9 +947,8 @@ class Manager(GenericTrainingManager):
loss_func = CrossEntropyLoss(ignore_index=self.dataset.tokens["pad"]) loss_func = CrossEntropyLoss(ignore_index=self.dataset.tokens["pad"])
sum_loss = 0 sum_loss = 0
x = batch_data["imgs"].to(self.device) b = batch_data["imgs"].shape[0]
y = batch_data["labels"].to(self.device) batch_data["labels"] = batch_data["labels"].to(self.device)
reduced_size = [s[:2] for s in batch_data["imgs_reduced_shape"]]
y_len = batch_data["labels_len"] y_len = batch_data["labels_len"]
if "label_noise_scheduler" in self.params["training"]: if "label_noise_scheduler" in self.params["training"]:
...@@ -963,38 +964,33 @@ class Manager(GenericTrainingManager): ...@@ -963,38 +964,33 @@ class Manager(GenericTrainingManager):
) )
/ self.params["training"]["label_noise_scheduler"]["total_num_steps"] / self.params["training"]["label_noise_scheduler"]["total_num_steps"]
) )
simulated_y_pred, y_len = self.add_label_noise(y, y_len, error_rate) simulated_y_pred, y_len = self.add_label_noise(
batch_data["labels"], y_len, error_rate
)
else: else:
simulated_y_pred = y simulated_y_pred = batch_data["labels"]
with autocast(enabled=self.device_params["use_amp"]): with autocast(enabled=self.device_params["use_amp"]):
hidden_predict = None hidden_predict = None
cache = None cache = None
raw_features = self.models["encoder"](x) features = self.models["encoder"](batch_data["imgs"].to(self.device))
features_size = raw_features.size() features_size = features.size()
b, c, h, w = features_size
if self.device_params["use_ddp"]: if self.device_params["use_ddp"]:
pos_features = self.models[ features = self.models[
"decoder" "decoder"
].module.features_updater.get_pos_features(raw_features) ].module.features_updater.get_pos_features(features)
else: else:
pos_features = self.models["decoder"].features_updater.get_pos_features( features = self.models["decoder"].features_updater.get_pos_features(
raw_features features
) )
features = torch.flatten(pos_features, start_dim=2, end_dim=3).permute( features = torch.flatten(features, start_dim=2, end_dim=3).permute(2, 0, 1)
2, 0, 1
)
enhanced_features = pos_features
enhanced_features = torch.flatten(
enhanced_features, start_dim=2, end_dim=3
).permute(2, 0, 1)
output, pred, hidden_predict, cache, weights = self.models["decoder"]( output, pred, hidden_predict, cache, weights = self.models["decoder"](
features, features,
enhanced_features,
simulated_y_pred[:, :-1], simulated_y_pred[:, :-1],
reduced_size, [s[:2] for s in batch_data["imgs_reduced_shape"]],
[max(y_len) for _ in range(b)], [max(y_len) for _ in range(b)],
features_size, features_size,
start=0, start=0,
...@@ -1003,7 +999,7 @@ class Manager(GenericTrainingManager): ...@@ -1003,7 +999,7 @@ class Manager(GenericTrainingManager):
keep_all_weights=True, keep_all_weights=True,
) )
loss_ce = loss_func(pred, y[:, 1:]) loss_ce = loss_func(pred, batch_data["labels"][:, 1:])
sum_loss += loss_ce sum_loss += loss_ce
with autocast(enabled=False): with autocast(enabled=False):
self.backward_loss(sum_loss) self.backward_loss(sum_loss)
...@@ -1028,7 +1024,6 @@ class Manager(GenericTrainingManager): ...@@ -1028,7 +1024,6 @@ class Manager(GenericTrainingManager):
def evaluate_batch(self, batch_data, metric_names): def evaluate_batch(self, batch_data, metric_names):
x = batch_data["imgs"].to(self.device) x = batch_data["imgs"].to(self.device)
reduced_size = [s[:2] for s in batch_data["imgs_reduced_shape"]]
max_chars = self.params["dataset"]["max_char_prediction"] max_chars = self.params["dataset"]["max_char_prediction"]
...@@ -1075,28 +1070,22 @@ class Manager(GenericTrainingManager): ...@@ -1075,28 +1070,22 @@ class Manager(GenericTrainingManager):
else: else:
features = self.models["encoder"](x) features = self.models["encoder"](x)
features_size = features.size() features_size = features.size()
if self.device_params["use_ddp"]: if self.device_params["use_ddp"]:
pos_features = self.models[ features = self.models[
"decoder" "decoder"
].module.features_updater.get_pos_features(features) ].module.features_updater.get_pos_features(features)
else: else:
pos_features = self.models["decoder"].features_updater.get_pos_features( features = self.models["decoder"].features_updater.get_pos_features(
features features
) )
features = torch.flatten(pos_features, start_dim=2, end_dim=3).permute( features = torch.flatten(features, start_dim=2, end_dim=3).permute(2, 0, 1)
2, 0, 1
)
enhanced_features = pos_features
enhanced_features = torch.flatten(
enhanced_features, start_dim=2, end_dim=3
).permute(2, 0, 1)
for i in range(0, max_chars): for i in range(0, max_chars):
output, pred, hidden_predict, cache, weights = self.models["decoder"]( output, pred, hidden_predict, cache, weights = self.models["decoder"](
features, features,
enhanced_features,
predicted_tokens, predicted_tokens,
reduced_size, [s[:2] for s in batch_data["imgs_reduced_shape"]],
predicted_tokens_len, predicted_tokens_len,
features_size, features_size,
start=0, start=0,
......
...@@ -159,14 +159,8 @@ class DAN: ...@@ -159,14 +159,8 @@ class DAN:
features = self.encoder(input_tensor.float()) features = self.encoder(input_tensor.float())
features_size = features.size() features_size = features.size()
pos_features = self.decoder.features_updater.get_pos_features(features) features = self.decoder.features_updater.get_pos_features(features)
features = torch.flatten(pos_features, start_dim=2, end_dim=3).permute( features = torch.flatten(features, start_dim=2, end_dim=3).permute(2, 0, 1)
2, 0, 1
)
enhanced_features = pos_features
enhanced_features = torch.flatten(
enhanced_features, start_dim=2, end_dim=3
).permute(2, 0, 1)
for i in range(0, self.max_chars): for i in range(0, self.max_chars):
( (
...@@ -177,7 +171,6 @@ class DAN: ...@@ -177,7 +171,6 @@ class DAN:
weights, weights,
) = self.decoder( ) = self.decoder(
features, features,
enhanced_features,
predicted_tokens, predicted_tokens,
input_sizes, input_sizes,
predicted_tokens_len, predicted_tokens_len,
......
...@@ -19,13 +19,12 @@ from albumentations.augmentations import ( ...@@ -19,13 +19,12 @@ from albumentations.augmentations import (
ToGray, ToGray,
) )
from albumentations.core.transforms_interface import ImageOnlyTransform from albumentations.core.transforms_interface import ImageOnlyTransform
from cv2 import dilate, erode from cv2 import dilate, erode, resize
from numpy import random from numpy import random
from PIL import Image
from torch import Tensor from torch import Tensor
from torch.distributions.uniform import Uniform from torch.distributions.uniform import Uniform
from torchvision.transforms import Compose, ToPILImage from torchvision.transforms import Compose, ToPILImage
from torchvision.transforms.functional import resize from torchvision.transforms.functional import resize as resize_tensor
class Preprocessing(str, Enum): class Preprocessing(str, Enum):
...@@ -47,7 +46,7 @@ class FixedHeightResize: ...@@ -47,7 +46,7 @@ class FixedHeightResize:
def __call__(self, img: Tensor) -> Tensor: def __call__(self, img: Tensor) -> Tensor:
size = (self.height, self._calc_new_width(img)) size = (self.height, self._calc_new_width(img))
return resize(img, size, antialias=False) return resize_tensor(img, size, antialias=False)
def _calc_new_width(self, img: Tensor) -> int: def _calc_new_width(self, img: Tensor) -> int:
aspect_ratio = img.shape[2] / img.shape[1] aspect_ratio = img.shape[2] / img.shape[1]
...@@ -64,7 +63,7 @@ class FixedWidthResize: ...@@ -64,7 +63,7 @@ class FixedWidthResize:
def __call__(self, img: Tensor) -> Tensor: def __call__(self, img: Tensor) -> Tensor:
size = (self._calc_new_height(img), self.width) size = (self._calc_new_height(img), self.width)
return resize(img, size, antialias=False) return resize_tensor(img, size, antialias=False)
def _calc_new_height(self, img: Tensor) -> int: def _calc_new_height(self, img: Tensor) -> int:
aspect_ratio = img.shape[1] / img.shape[2] aspect_ratio = img.shape[1] / img.shape[2]
...@@ -89,7 +88,7 @@ class MaxResize: ...@@ -89,7 +88,7 @@ class MaxResize:
ratio = min(height_ratio, width_ratio) ratio = min(height_ratio, width_ratio)
new_width = int(width * ratio) new_width = int(width * ratio)
new_height = int(height * ratio) new_height = int(height * ratio)
return resize(img, (new_height, new_width), antialias=False) return resize_tensor(img, (new_height, new_width), antialias=False)
class Dilation: class Dilation:
...@@ -142,12 +141,11 @@ class ErosionDilation(ImageOnlyTransform): ...@@ -142,12 +141,11 @@ class ErosionDilation(ImageOnlyTransform):
kernel_h = randint(self.min_kernel, self.max_kernel) kernel_h = randint(self.min_kernel, self.max_kernel)
kernel_w = randint(self.min_kernel, self.max_kernel) kernel_w = randint(self.min_kernel, self.max_kernel)
kernel = np.ones((kernel_h, kernel_w), np.uint8) kernel = np.ones((kernel_h, kernel_w), np.uint8)
augmented_image = ( return (
Erosion(kernel, iterations=self.iterations)(img) Erosion(kernel, iterations=self.iterations)(img)
if random.random() < 0.5 if random.random() < 0.5
else Dilation(kernel=kernel, iterations=self.iterations)(img) else Dilation(kernel=kernel, iterations=self.iterations)(img)
) )
return augmented_image
class DPIAdjusting(ImageOnlyTransform): class DPIAdjusting(ImageOnlyTransform):
...@@ -170,12 +168,7 @@ class DPIAdjusting(ImageOnlyTransform): ...@@ -170,12 +168,7 @@ class DPIAdjusting(ImageOnlyTransform):
def apply(self, img: np.ndarray, **params): def apply(self, img: np.ndarray, **params):
factor = float(Uniform(self.min_factor, self.max_factor).sample()) factor = float(Uniform(self.min_factor, self.max_factor).sample())
img = Image.fromarray(img) return resize(img, None, fx=factor, fy=factor)
augmented_image = img.resize(
(int(np.ceil(img.width * factor)), int(np.ceil(img.height * factor))),
Image.BILINEAR,
)
return np.array(augmented_image)
def get_preprocessing_transforms( def get_preprocessing_transforms(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment