
Memory leak: Catch error and retry

Merged: Manon Blanco requested to merge memory-leak-catch-error-and-retry into main
All threads resolved.
1 file changed: +51 −47
@@ -636,7 +636,41 @@ class GenericTrainingManager:
             Recursive function to use `train_batch` but catch `OutOfMemoryError` and retry on smaller batches
             """
             try:
-                return [self.train_batch(batch_data, metric_names)]
+                # train on batch data and compute metrics
+                batch_values = self.train_batch(batch_data, metric_names)
+                batch_metrics = self.metric_manager["train"].compute_metrics(
+                    batch_values, metric_names
+                )
+                batch_metrics["names"] = batch_data["names"]
+                # Merge metrics if Distributed Data Parallel is used
+                if self.device_params["use_ddp"]:
+                    batch_metrics = self.merge_ddp_metrics(batch_metrics)
+                # Update learning rate via scheduler if one is used
+                if self.params["training"]["lr_schedulers"]:
+                    for model_name in self.models:
+                        key = (
+                            "all"
+                            if "all" in self.params["training"]["lr_schedulers"]
+                            else model_name
+                        )
+                        if (
+                            model_name in self.lr_schedulers
+                            and ind_batch
+                            % self.params["training"]["lr_schedulers"][key][
+                                "step_interval"
+                            ]
+                            == 0
+                        ):
+                            self.lr_schedulers[model_name].step(
+                                len(batch_metrics["names"])
+                            )
+                # Update dropout scheduler
+                self.dropout_scheduler.step(len(batch_metrics["names"]))
+                self.dropout_scheduler.update_dropout_rate()
+                # Add batch metrics values to epoch metrics values
+                self.metric_manager["train"].update_metrics(batch_metrics)
             except OutOfMemoryError:
                 batch_size = len(batch_data["names"])
                 logger.warning(
@@ -646,19 +680,21 @@ class GenericTrainingManager:
                     raise
                 # Split batch by two and retry
-                return [
-                    for smaller_batch_data in [
-                        {
-                            key: values[: round(batch_size / 2)]
-                            for key, values in batch_data.items()
-                        },
-                        {
-                            key: values[round(batch_size / 2) :]
-                            for key, values in batch_data.items()
-                        },
-                    ]
-                ]
+                for smaller_batch_data in [
+                    {
+                        key: values[: round(batch_size / 2)]
+                        for key, values in batch_data.items()
+                    },
+                    {
+                        key: values[round(batch_size / 2) :]
+                        for key, values in batch_data.items()
+                    },
+                ]:
+                    # Set models trainable
+                    for model_name in self.models:
+                        self.models[model_name].train()
+                    train_batch(smaller_batch_data, metric_names)
         # init tensorboard file and output param summary file
         if self.is_master:
@@ -699,40 +735,8 @@ class GenericTrainingManager:
                 # iterates over mini-batch data
                 for ind_batch, batch_data in enumerate(self.dataset.train_loader):
                     # train on batch data and compute metrics
-                    for batch_values in train_batch(batch_data, metric_names):
-                        batch_metrics = self.metric_manager["train"].compute_metrics(
-                            batch_values, metric_names
-                        )
-                        batch_metrics["names"] = batch_data["names"]
-                        # Merge metrics if Distributed Data Parallel is used
-                        if self.device_params["use_ddp"]:
-                            batch_metrics = self.merge_ddp_metrics(batch_metrics)
-                        # Update learning rate via scheduler if one is used
-                        if self.params["training"]["lr_schedulers"]:
-                            for model_name in self.models:
-                                key = (
-                                    "all"
-                                    if "all" in self.params["training"]["lr_schedulers"]
-                                    else model_name
-                                )
-                                if (
-                                    model_name in self.lr_schedulers
-                                    and ind_batch
-                                    % self.params["training"]["lr_schedulers"][key][
-                                        "step_interval"
-                                    ]
-                                    == 0
-                                ):
-                                    self.lr_schedulers[model_name].step(
-                                        len(batch_metrics["names"])
-                                    )
-                        # Update dropout scheduler
-                        self.dropout_scheduler.step(len(batch_metrics["names"]))
-                        self.dropout_scheduler.update_dropout_rate()
-                        # Add batch metrics values to epoch metrics values
-                        self.metric_manager["train"].update_metrics(batch_metrics)
+                    train_batch(batch_data, metric_names)
                     display_values = self.metric_manager["train"].get_display_values()
                     pbar.set_postfix(values=str(display_values))
                     pbar.update(len(batch_data["names"]) * self.nb_workers)
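For readers outside the project, the pattern this merge request implements is: try to train on the full batch, and if an out-of-memory error is raised, split the batch in two and recurse on each half. A minimal standalone sketch of that idea, assuming a hypothetical `run_step` training callback and `torch.cuda.OutOfMemoryError` (available in recent PyTorch releases), could look like:

import logging

import torch

logger = logging.getLogger(__name__)


def train_with_retry(batch_data: dict, run_step) -> None:
    """Train on `batch_data`; on CUDA OOM, split the batch in two and retry.

    `batch_data` maps keys (e.g. "names", "imgs") to sequences of equal length;
    `run_step` is a hypothetical callback that trains on a single batch.
    """
    try:
        run_step(batch_data)
    except torch.cuda.OutOfMemoryError:
        batch_size = len(batch_data["names"])
        logger.warning("Failed to train batch of size %d, retrying on halves", batch_size)
        if batch_size < 2:
            # A single sample already exceeds memory: nothing left to split
            raise
        half = round(batch_size / 2)
        for smaller_batch_data in (
            {key: values[:half] for key, values in batch_data.items()},
            {key: values[half:] for key, values in batch_data.items()},
        ):
            train_with_retry(smaller_batch_data, run_step)

The recursion depth is bounded by log2 of the original batch size, and a batch of size 1 that still does not fit re-raises the error instead of retrying forever.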