Skip to content
Snippets Groups Projects

Limit training and validation steps

Merged Solene Tarride requested to merge limit-steps-in-epoch into main
All threads resolved!
2 files
+ 13
11
Compare changes
  • Side-by-side
  • Inline
Files
2
@@ -67,8 +67,8 @@ class GenericTrainingManager:
"nb_logged_images"
]
self.limit_train_steps = self.params["training"]["data"]["limit_train_steps"]
self.limit_val_steps = self.params["training"]["validation"]["limit_val_steps"]
self.limit_train_steps = self.params["training"]["data"].get("limit_train_steps", None)
self.limit_val_steps = self.params["training"]["validation"].get("limit_val_steps", None)
self.optimizers = dict()
self.optimizers_named_params_by_group = dict()
@@ -671,7 +671,7 @@ class GenericTrainingManager:
# iterates over mini-batch data
for ind_batch, batch_data in enumerate(self.dataset.train_loader):
# Limit the number of steps
if ind_batch > self.limit_train_steps:
if self.limit_train_steps and ind_batch > self.limit_train_steps:
break
# train on batch data and compute metrics
@@ -805,7 +805,7 @@ class GenericTrainingManager:
# iterate over batch data
for ind_batch, batch_data in enumerate(loader):
# Limit the number of steps
if ind_batch > self.limit_val_steps:
if self.limit_val_steps and ind_batch > self.limit_val_steps:
break
# eval batch data and compute metrics
Loading