From ded809b1f198e0b41686b9c27117967337d8ba3f Mon Sep 17 00:00:00 2001
From: manonBlanco <blanco@teklia.com>
Date: Thu, 17 Aug 2023 16:11:14 +0200
Subject: [PATCH] Rename predict and evaluation functions

---
 dan/ocr/manager/training.py | 10 +++++-----
 dan/ocr/train.py            |  2 +-
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/dan/ocr/manager/training.py b/dan/ocr/manager/training.py
index 97d37a23..2b13c515 100644
--- a/dan/ocr/manager/training.py
+++ b/dan/ocr/manager/training.py
@@ -670,7 +670,7 @@ class GenericTrainingManager:
             ):
                 for valid_set_name in self.dataset.valid_loaders:
                     # evaluate set and compute metrics
-                    eval_values = self.evaluate(
+                    eval_values = self.validate(
                         valid_set_name, mlflow_logging=mlflow_logging
                     )
                     # log valid metrics in tensorboard file
@@ -702,7 +702,7 @@ class GenericTrainingManager:
                 self.save_model(epoch=num_epoch, name="last")
                 self.writer.flush()
 
-    def evaluate(self, set_name, mlflow_logging=False, **kwargs):
+    def validate(self, set_name, mlflow_logging=False, **kwargs):
         """
         Main loop for validation
         """
@@ -720,7 +720,7 @@ class GenericTrainingManager:
             tokens=self.tokens,
         )
         with tqdm(total=len(loader.dataset)) as pbar:
-            pbar.set_description("Evaluation E{}".format(self.latest_epoch))
+            pbar.set_description("Validation E{}".format(self.latest_epoch))
             with torch.no_grad():
                 # iterate over batch data
                 for ind_batch, batch_data in enumerate(loader):
@@ -751,7 +751,7 @@ class GenericTrainingManager:
                 )
         return display_values
 
-    def predict(
+    def evaluate(
         self, custom_name, sets_list, metric_names, mlflow_logging=False, output=False
     ):
         """
@@ -772,7 +772,7 @@ class GenericTrainingManager:
         )
 
         with tqdm(total=len(loader.dataset)) as pbar:
-            pbar.set_description("Prediction")
+            pbar.set_description("Evaluation")
             with torch.no_grad():
                 for ind_batch, batch_data in enumerate(loader):
                     # iterates over batch data
diff --git a/dan/ocr/train.py b/dan/ocr/train.py
index 7aea40fd..bb0bc48e 100644
--- a/dan/ocr/train.py
+++ b/dan/ocr/train.py
@@ -50,7 +50,7 @@ def train_and_test(rank, params, mlflow_logging=False):
     metrics = ["cer", "wer", "wer_no_punct", "time"]
     for dataset_name in params["dataset"]["datasets"]:
         for set_name in ["test", "val", "train"]:
-            model.predict(
+            model.evaluate(
                 "{}-{}".format(dataset_name, set_name),
                 [
                     (dataset_name, set_name),
-- 
GitLab