Skip to content
Snippets Groups Projects

Limit training and validation steps

Merged Solene Tarride requested to merge limit-steps-in-epoch into main
All threads resolved!
2 files
+ 13
− 11
Compare changes
  • Side-by-side
  • Inline
Files
2
@@ -67,6 +67,13 @@ class GenericTrainingManager:
"nb_logged_images"
]
self.limit_train_steps = self.params["training"]["data"].get(
"limit_train_steps"
)
self.limit_val_steps = self.params["training"]["validation"].get(
"limit_val_steps"
)
self.optimizers = dict()
self.optimizers_named_params_by_group = dict()
self.lr_schedulers = dict()
@@ -667,6 +674,10 @@ class GenericTrainingManager:
pbar.set_description("EPOCH {}/{}".format(num_epoch, nb_epochs))
# iterates over mini-batch data
for ind_batch, batch_data in enumerate(self.dataset.train_loader):
# Limit the number of steps
if self.limit_train_steps and ind_batch > self.limit_train_steps:
break
# train on batch data and compute metrics
batch_values = self.train_batch(batch_data, metric_names)
batch_metrics = self.metric_manager["train"].compute_metrics(
@@ -797,6 +808,10 @@ class GenericTrainingManager:
with torch.no_grad():
# iterate over batch data
for ind_batch, batch_data in enumerate(loader):
# Limit the number of steps
if self.limit_val_steps and ind_batch > self.limit_val_steps:
break
# eval batch data and compute metrics
batch_values = self.evaluate_batch(batch_data, metric_names)
batch_metrics = self.metric_manager[set_name].compute_metrics(
Loading