diff --git a/dan/ocr/manager/training.py b/dan/ocr/manager/training.py index 97d37a235e61d241df330f07333a24c6f308e2cd..2b13c515ee94c1ecf5722890d9080020832ba22b 100644 --- a/dan/ocr/manager/training.py +++ b/dan/ocr/manager/training.py @@ -670,7 +670,7 @@ class GenericTrainingManager: ): for valid_set_name in self.dataset.valid_loaders: # evaluate set and compute metrics - eval_values = self.evaluate( + eval_values = self.validate( valid_set_name, mlflow_logging=mlflow_logging ) # log valid metrics in tensorboard file @@ -702,7 +702,7 @@ class GenericTrainingManager: self.save_model(epoch=num_epoch, name="last") self.writer.flush() - def evaluate(self, set_name, mlflow_logging=False, **kwargs): + def validate(self, set_name, mlflow_logging=False, **kwargs): """ Main loop for validation """ @@ -720,7 +720,7 @@ class GenericTrainingManager: tokens=self.tokens, ) with tqdm(total=len(loader.dataset)) as pbar: - pbar.set_description("Evaluation E{}".format(self.latest_epoch)) + pbar.set_description("Validation E{}".format(self.latest_epoch)) with torch.no_grad(): # iterate over batch data for ind_batch, batch_data in enumerate(loader): @@ -751,7 +751,7 @@ class GenericTrainingManager: ) return display_values - def predict( + def evaluate( self, custom_name, sets_list, metric_names, mlflow_logging=False, output=False ): """ @@ -772,7 +772,7 @@ class GenericTrainingManager: ) with tqdm(total=len(loader.dataset)) as pbar: - pbar.set_description("Prediction") + pbar.set_description("Evaluation") with torch.no_grad(): for ind_batch, batch_data in enumerate(loader): # iterates over batch data diff --git a/dan/ocr/train.py b/dan/ocr/train.py index 7aea40fd144e1fb9e725715aba1e35041256b35e..bb0bc48e094c8145350f64f081d4d212d1ff77f2 100644 --- a/dan/ocr/train.py +++ b/dan/ocr/train.py @@ -50,7 +50,7 @@ def train_and_test(rank, params, mlflow_logging=False): metrics = ["cer", "wer", "wer_no_punct", "time"] for dataset_name in params["dataset"]["datasets"]: for set_name in ["test", "val", "train"]: - model.predict( + model.evaluate( "{}-{}".format(dataset_name, set_name), [ (dataset_name, set_name),