
Memory leak: Catch error and retry

Merged. Manon Blanco requested to merge memory-leak-catch-error-and-retry into main
1 file changed: +4 -3
@@ -680,9 +680,6 @@ class GenericTrainingManager:
raise
# Split batch by two and retry
with torch.no_grad():
torch.cuda.empty_cache()
for smaller_batch_data in [
    {
        key: values[: round(batch_size / 2)]
@@ -693,6 +690,10 @@ class GenericTrainingManager:
        for key, values in batch_data.items()
    },
]:
    # Set models trainable
    for model_name in self.models:
        self.models[model_name].train()
    train_batch(smaller_batch_data, metric_names)
# init tensorboard file and output param summary file
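
Note on the retry pattern in this diff: the sketch below illustrates the "split batch by two and retry" idea in isolation. It is not the code from this merge request. `split_batch` and `train_with_retry` are hypothetical names, the built-in `MemoryError` stands in for the CUDA out-of-memory error, and the second-half slice is an assumption because that part of the hunk is not shown above. The `batch_size < 2` guard is also an assumption, added because `round(1 / 2)` is `0` in Python and would produce an empty first half.

```python
def split_batch(batch_data, batch_size):
    """Split every value of a batch dict into two halves (hypothetical helper)."""
    half = round(batch_size / 2)
    return [
        # First half, sliced as in the diff
        {key: values[:half] for key, values in batch_data.items()},
        # Second half: assumed, this slice is not visible in the diff
        {key: values[half:] for key, values in batch_data.items()},
    ]


def train_with_retry(train_batch, batch_data, batch_size):
    """Run train_batch once; on a memory error, retry on each half of the batch."""
    try:
        train_batch(batch_data)
    except MemoryError:  # stand-in for the GPU out-of-memory error
        if batch_size < 2:
            # Assumed guard: a single-item batch cannot be split further
            raise
        for smaller_batch_data in split_batch(batch_data, batch_size):
            train_batch(smaller_batch_data)
```

With a `train_batch` that fails on four items but succeeds on two, `train_with_retry` ends up calling it once per half. A real training loop would also need to free cached GPU memory and put the models back in training mode before retrying, as the diff does.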