diff --git a/dan/manager/training.py b/dan/manager/training.py
index f5779a14fb220465773c30ea82ebea3a3e1498ac..67b9f6bf1b38c8dce424c94275676e26403c5248 100644
--- a/dan/manager/training.py
+++ b/dan/manager/training.py
@@ -43,17 +43,6 @@ class GenericTrainingManager:
         self.paths = None
         self.latest_step = 0
         self.latest_epoch = -1
-        self.latest_batch = 0
-        self.total_batch = 0
-        self.grad_acc_step = 0
-        self.latest_train_metrics = dict()
-        self.latest_valid_metrics = dict()
-        self.curriculum_info = dict()
-        self.curriculum_info["latest_valid_metrics"] = dict()
-        self.phase = None
-        self.max_mem_usage_by_epoch = list()
-        self.losses = list()
-        self.lr_values = list()
         self.scaler = None
@@ -512,7 +501,7 @@ class GenericTrainingManager:
     def backward_loss(self, loss, retain_graph=False):
         self.scaler.scale(loss).backward(retain_graph=retain_graph)
 
-    def step_optimizers(self, increment_step=True, names=None):
+    def step_optimizers(self, names=None):
         for model_name in self.optimizers:
             if names and model_name not in names:
                 continue
@@ -559,11 +548,6 @@ class GenericTrainingManager:
         self.init_curriculum()
         # perform epochs
         for num_epoch in range(self.latest_epoch + 1, nb_epochs):
-            self.dataset.train_dataset.training_info = {
-                "epoch": self.latest_epoch,
-                "step": self.latest_step,
-            }
-            self.phase = "train"
             # Check maximum training time stop condition
             if (
                 self.params["training_params"]["max_training_time"]
@@ -588,8 +572,6 @@ class GenericTrainingManager:
                 pbar.set_description("EPOCH {}/{}".format(num_epoch, nb_epochs))
                 # iterates over mini-batch data
                 for ind_batch, batch_data in enumerate(self.dataset.train_loader):
-                    self.latest_batch = ind_batch + 1
-                    self.total_batch += 1
                     # train on batch data and compute metrics
                     batch_values = self.train_batch(batch_data, metric_names)
                     batch_metrics = self.metric_manager["train"].compute_metrics(
@@ -651,7 +633,6 @@ class GenericTrainingManager:
                             display_values[key],
                             num_epoch,
                         )
-            self.latest_train_metrics = display_values
 
             # evaluate and compute metrics for valid sets
             if (
@@ -664,7 +645,6 @@ class GenericTrainingManager:
                     eval_values = self.evaluate(
                         valid_set_name, mlflow_logging=mlflow_logging
                     )
-                    self.latest_valid_metrics = eval_values
                     # log valid metrics in tensorboard file
                     if self.is_master:
                         for key in eval_values.keys():
@@ -716,7 +696,6 @@ class GenericTrainingManager:
         """
         Main loop for validation
         """
-        self.phase = "eval"
         loader = self.dataset.valid_loaders[set_name]
         # Set models in eval mode
         for model_name in self.models.keys():
@@ -733,7 +712,6 @@ class GenericTrainingManager:
         with torch.no_grad():
             # iterate over batch data
             for ind_batch, batch_data in enumerate(loader):
-                self.latest_batch = ind_batch + 1
                 # eval batch data and compute metrics
                 batch_values = self.evaluate_batch(batch_data, metric_names)
                 batch_metrics = self.metric_manager[set_name].compute_metrics(
@@ -767,7 +745,6 @@ class GenericTrainingManager:
         """
         Main loop for evaluation
         """
-        self.phase = "predict"
         metric_names = metric_names.copy()
         self.dataset.generate_test_loader(custom_name, sets_list)
         loader = self.dataset.test_loaders[custom_name]
@@ -785,7 +762,6 @@ class GenericTrainingManager:
         with torch.no_grad():
             for ind_batch, batch_data in enumerate(loader):
                 # iterates over batch data
-                self.latest_batch = ind_batch + 1
                 # eval batch data and compute metrics
                 batch_values = self.evaluate_batch(batch_data, metric_names)
                 batch_metrics = self.metric_manager[custom_name].compute_metrics(
@@ -903,10 +879,6 @@ class GenericTrainingManager:
         dist.all_gather(res, tensor)
         return list(torch.cat(res, dim=0).flatten().cpu().numpy())
 
-    @staticmethod
-    def cleanup():
-        dist.destroy_process_group()
-
     def train_batch(self, batch_data, metric_names):
         raise NotImplementedError
 
@@ -916,20 +888,6 @@ class GenericTrainingManager:
     def evaluate_batch(self, batch_data, metric_names):
         raise NotImplementedError
 
     def init_curriculum(self):
         raise NotImplementedError
 
-    def update_curriculum(self):
-        raise NotImplementedError
-
-    def add_checkpoint_info(self, load_mode="last", **kwargs):
-        for filename in os.listdir(self.paths["checkpoints"]):
-            if load_mode in filename:
-                checkpoint_path = os.path.join(self.paths["checkpoints"], filename)
-                checkpoint = torch.load(checkpoint_path)
-                for key in kwargs.keys():
-                    checkpoint[key] = kwargs[key]
-                torch.save(checkpoint, checkpoint_path)
-                return
-        self.save_model(self.latest_epoch, "last")
-
     def load_save_info(self, info_dict):
         """
         Load curriculum info from saved model info