diff --git a/dan/manager/training.py b/dan/manager/training.py index 72acce307e2ca0f112584753269c1b41847ed74c..f5779a14fb220465773c30ea82ebea3a3e1498ac 100644 --- a/dan/manager/training.py +++ b/dan/manager/training.py @@ -532,10 +532,7 @@ class GenericTrainingManager: def zero_optimizers(self, set_to_none=True): for model_name in self.optimizers: - self.zero_optimizer(model_name, set_to_none) - - def zero_optimizer(self, model_name, set_to_none=True): - self.optimizers[model_name].zero_grad(set_to_none=set_to_none) + self.optimizers[model_name].zero_grad(set_to_none=set_to_none) def train(self, mlflow_logging=False): """