
Memory leak: Catch error and retry

Merged: Manon Blanco requested to merge memory-leak-catch-error-and-retry into main
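Context for the change: the diff below deletes a nested train_batch wrapper that caught CUDA out-of-memory errors and recursively retried on half-sized batches, drops the corresponding OutOfMemoryError import, and inlines the normal training step back into the mini-batch loop. The sketch below is only an illustration of the removed pattern, not the project's code (run_step and batch are hypothetical stand-ins for self.train_batch and batch_data), and it shows a plausible source of the leak the MR title refers to: while the except block is running, the caught exception's traceback keeps the failed step's frames, and therefore its CUDA tensors, alive, so each recursive retry runs before that memory has been released.

import torch


def train_with_retry(run_step, batch):
    """Run one training step; on CUDA OOM, retry on two half-batches."""
    try:
        return run_step(batch)
    except torch.cuda.OutOfMemoryError:
        if len(batch) == 1:
            # A single sample still does not fit on the GPU: give up.
            raise
        # While this handler is active, the caught exception (via its traceback)
        # still references the tensors of the failed attempt, so the retries
        # below compete with GPU memory that has not been freed yet.
        half = len(batch) // 2
        for sub_batch in (batch[:half], batch[half:]):
            train_with_retry(run_step, sub_batch)

With the wrapper gone, an OutOfMemoryError now propagates immediately and the run fails fast instead of silently retrying on smaller batches.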
@@ -17,7 +17,6 @@ import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
 import yaml
-from torch.cuda import OutOfMemoryError
 from torch.cuda.amp import GradScaler, autocast
 from torch.nn import CrossEntropyLoss
 from torch.nn.init import kaiming_uniform_
@@ -630,75 +629,6 @@ class GenericTrainingManager:
         """
         Main training loop
         """
-        def train_batch(batch_data, metric_names):
-            """
-            Recursive function to use `train_batch` but catch `OutOfMemoryError` and retry on smaller batches
-            """
-            try:
-                # train on batch data and compute metrics
-                batch_values = self.train_batch(batch_data, metric_names)
-                batch_metrics = self.metric_manager["train"].compute_metrics(
-                    batch_values, metric_names
-                )
-                batch_metrics["names"] = batch_data["names"]
-                # Merge metrics if Distributed Data Parallel is used
-                if self.device_params["use_ddp"]:
-                    batch_metrics = self.merge_ddp_metrics(batch_metrics)
-                # Update learning rate via scheduler if one is used
-                if self.params["training"]["lr_schedulers"]:
-                    for model_name in self.models:
-                        key = (
-                            "all"
-                            if "all" in self.params["training"]["lr_schedulers"]
-                            else model_name
-                        )
-                        if (
-                            model_name in self.lr_schedulers
-                            and ind_batch
-                            % self.params["training"]["lr_schedulers"][key][
-                                "step_interval"
-                            ]
-                            == 0
-                        ):
-                            self.lr_schedulers[model_name].step(
-                                len(batch_metrics["names"])
-                            )
-                # Update dropout scheduler
-                self.dropout_scheduler.step(len(batch_metrics["names"]))
-                self.dropout_scheduler.update_dropout_rate()
-                # Add batch metrics values to epoch metrics values
-                self.metric_manager["train"].update_metrics(batch_metrics)
-            except OutOfMemoryError:
-                batch_size = len(batch_data["names"])
-                logger.warning(
-                    f"Failed to train batch size of {batch_size} image(s). Trying with smaller batches..."
-                )
-                if batch_size == 1:
-                    raise
-                # Split batch by two and retry
-                for smaller_batch_data in [
-                    {
-                        key: values[: round(batch_size / 2)]
-                        for key, values in batch_data.items()
-                    },
-                    {
-                        key: values[round(batch_size / 2) :]
-                        for key, values in batch_data.items()
-                    },
-                ]:
-                    with autocast(enabled=False):
-                        self.zero_optimizers()
-                        # Set models trainable
-                        for model_name in self.models:
-                            self.models[model_name].train()
-                        train_batch(smaller_batch_data, metric_names)
         # init tensorboard file and output param summary file
         if self.is_master:
             self.writer = SummaryWriter(self.paths["results"])
@@ -738,8 +668,40 @@ class GenericTrainingManager:
                 # iterates over mini-batch data
                 for ind_batch, batch_data in enumerate(self.dataset.train_loader):
                     # train on batch data and compute metrics
-                    train_batch(batch_data, metric_names)
+                    batch_values = self.train_batch(batch_data, metric_names)
+                    batch_metrics = self.metric_manager["train"].compute_metrics(
+                        batch_values, metric_names
+                    )
+                    batch_metrics["names"] = batch_data["names"]
+                    # Merge metrics if Distributed Data Parallel is used
+                    if self.device_params["use_ddp"]:
+                        batch_metrics = self.merge_ddp_metrics(batch_metrics)
+                    # Update learning rate via scheduler if one is used
+                    if self.params["training"]["lr_schedulers"]:
+                        for model_name in self.models:
+                            key = (
+                                "all"
+                                if "all" in self.params["training"]["lr_schedulers"]
+                                else model_name
+                            )
+                            if (
+                                model_name in self.lr_schedulers
+                                and ind_batch
+                                % self.params["training"]["lr_schedulers"][key][
+                                    "step_interval"
+                                ]
+                                == 0
+                            ):
+                                self.lr_schedulers[model_name].step(
+                                    len(batch_metrics["names"])
+                                )
+                    # Update dropout scheduler
+                    self.dropout_scheduler.step(len(batch_metrics["names"]))
+                    self.dropout_scheduler.update_dropout_rate()
+                    # Add batch metrics values to epoch metrics values
+                    self.metric_manager["train"].update_metrics(batch_metrics)
                     display_values = self.metric_manager["train"].get_display_values()
                     pbar.set_postfix(values=str(display_values))
                     pbar.update(len(batch_data["names"]) * self.nb_workers)