From ded809b1f198e0b41686b9c27117967337d8ba3f Mon Sep 17 00:00:00 2001
From: manonBlanco <blanco@teklia.com>
Date: Thu, 17 Aug 2023 16:11:14 +0200
Subject: [PATCH] Rename predict and evaluation functions

---
 dan/ocr/manager/training.py | 10 +++++-----
 dan/ocr/train.py            |  2 +-
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/dan/ocr/manager/training.py b/dan/ocr/manager/training.py
index 97d37a23..2b13c515 100644
--- a/dan/ocr/manager/training.py
+++ b/dan/ocr/manager/training.py
@@ -670,7 +670,7 @@ class GenericTrainingManager:
         ):
             for valid_set_name in self.dataset.valid_loaders:
                 # evaluate set and compute metrics
-                eval_values = self.evaluate(
+                eval_values = self.validate(
                     valid_set_name, mlflow_logging=mlflow_logging
                 )
                 # log valid metrics in tensorboard file
@@ -702,7 +702,7 @@ class GenericTrainingManager:
                     self.save_model(epoch=num_epoch, name="last")
             self.writer.flush()
 
-    def evaluate(self, set_name, mlflow_logging=False, **kwargs):
+    def validate(self, set_name, mlflow_logging=False, **kwargs):
         """
         Main loop for validation
         """
@@ -720,7 +720,7 @@ class GenericTrainingManager:
             tokens=self.tokens,
         )
         with tqdm(total=len(loader.dataset)) as pbar:
-            pbar.set_description("Evaluation E{}".format(self.latest_epoch))
+            pbar.set_description("Validation E{}".format(self.latest_epoch))
             with torch.no_grad():
                 # iterate over batch data
                 for ind_batch, batch_data in enumerate(loader):
@@ -751,7 +751,7 @@ class GenericTrainingManager:
         )
         return display_values
 
-    def predict(
+    def evaluate(
        self, custom_name, sets_list, metric_names, mlflow_logging=False, output=False
     ):
         """
@@ -772,7 +772,7 @@ class GenericTrainingManager:
         )
 
         with tqdm(total=len(loader.dataset)) as pbar:
-            pbar.set_description("Prediction")
+            pbar.set_description("Evaluation")
             with torch.no_grad():
                 for ind_batch, batch_data in enumerate(loader):
                     # iterates over batch data
diff --git a/dan/ocr/train.py b/dan/ocr/train.py
index 7aea40fd..bb0bc48e 100644
--- a/dan/ocr/train.py
+++ b/dan/ocr/train.py
@@ -50,7 +50,7 @@ def train_and_test(rank, params, mlflow_logging=False):
     metrics = ["cer", "wer", "wer_no_punct", "time"]
     for dataset_name in params["dataset"]["datasets"]:
         for set_name in ["test", "val", "train"]:
-            model.predict(
+            model.evaluate(
                 "{}-{}".format(dataset_name, set_name),
                 [
                     (dataset_name, set_name),
-- 
GitLab
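
For readers skimming the patch, the sketch below (not part of the patch itself) shows how the renamed methods line up after this change: the per-epoch validation loop now calls validate() (formerly evaluate), while the final per-set metric computation is exposed as evaluate() (formerly predict). The StubTrainingManager class, its print-based bodies, and the "my_dataset" name are hypothetical placeholders; only the method names, signatures, the metric list, and the call pattern from train_and_test come from the diff.

# Hypothetical stub illustrating the renamed API surface after this patch.
# Only the method names/signatures and the call pattern in train_and_test
# are taken from the diff; everything else is a placeholder.
class StubTrainingManager:
    def validate(self, set_name, mlflow_logging=False, **kwargs):
        """Main loop for validation (named `evaluate` before the rename)."""
        print("Validating on {} (mlflow={})".format(set_name, mlflow_logging))
        return {"cer": 0.0, "wer": 0.0}

    def evaluate(
        self, custom_name, sets_list, metric_names, mlflow_logging=False, output=False
    ):
        """Per-set evaluation (named `predict` before the rename)."""
        print("Evaluating {} on {} for {}".format(custom_name, sets_list, metric_names))

model = StubTrainingManager()

# Per-epoch validation, as in GenericTrainingManager.train:
model.validate("val", mlflow_logging=False)

# Final evaluation, mirroring the updated call in dan/ocr/train.py:
metrics = ["cer", "wer", "wer_no_punct", "time"]
for dataset_name in ["my_dataset"]:  # placeholder dataset name
    for set_name in ["test", "val", "train"]:
        model.evaluate(
            "{}-{}".format(dataset_name, set_name),
            [(dataset_name, set_name)],
            metrics,
        )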