From b1258a41539989bcfb8d4cdeb87645fb37ca99a6 Mon Sep 17 00:00:00 2001
From: Yoann Schneider <yschneider@teklia.com>
Date: Mon, 10 Mar 2025 11:39:48 +0100
Subject: [PATCH] Use these variables in managers

---
 dan/ocr/manager/ocr.py      | 12 +++++++-----
 dan/ocr/manager/training.py |  4 ++--
 2 files changed, 9 insertions(+), 7 deletions(-)

diff --git a/dan/ocr/manager/ocr.py b/dan/ocr/manager/ocr.py
index 01beb1a9..fee3519b 100644
--- a/dan/ocr/manager/ocr.py
+++ b/dan/ocr/manager/ocr.py
@@ -10,7 +10,7 @@ import torch
 from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler
 
-from dan import TRAIN_NAME
+from dan import TEST_NAME, TRAIN_NAME, VAL_NAME
 from dan.ocr.manager.dataset import OCRDataset
 from dan.ocr.transforms import get_augmentation_transforms, get_preprocessing_transforms
 from dan.utils import pad_images, pad_sequences_1D
@@ -80,10 +80,12 @@ class OCRDatasetManager:
 
         self.mean, self.std = self.train_dataset.compute_std_mean()
 
-        for custom_name in self.params["val"]:
+        for custom_name in self.params[VAL_NAME]:
             self.valid_datasets[custom_name] = OCRDataset(
-                set_name="val",
-                paths_and_sets=self.get_paths_and_sets(self.params["val"][custom_name]),
+                set_name=VAL_NAME,
+                paths_and_sets=self.get_paths_and_sets(
+                    self.params[VAL_NAME][custom_name]
+                ),
                 charset=self.charset,
                 tokens=self.tokens,
                 preprocessing_transforms=self.preprocessing,
@@ -169,7 +171,7 @@ class OCRDatasetManager:
                 {"path": self.params["datasets"][set_info[0]], "set_name": set_info[1]}
             )
         self.test_datasets[custom_name] = OCRDataset(
-            set_name="test",
+            set_name=TEST_NAME,
             paths_and_sets=paths_and_sets,
             charset=self.charset,
             tokens=self.tokens,
diff --git a/dan/ocr/manager/training.py b/dan/ocr/manager/training.py
index 3dc82e11..d8cc47fa 100644
--- a/dan/ocr/manager/training.py
+++ b/dan/ocr/manager/training.py
@@ -25,7 +25,7 @@ from torch.utils.tensorboard import SummaryWriter
 from torchvision.transforms import Compose, PILToTensor
 from tqdm import tqdm
 
-from dan import TRAIN_NAME
+from dan import TRAIN_NAME, VAL_NAME
 from dan.ocr import wandb
 from dan.ocr.decoder import CTCLanguageDecoder, GlobalHTADecoder
 from dan.ocr.encoder import FCN_Encoder
@@ -873,7 +873,7 @@ class GenericTrainingManager:
                 # log metrics in MLflow
                 logging_metrics(
                     display_values,
-                    "val",
+                    VAL_NAME,
                     self.latest_epoch,
                     mlflow_logging,
                     self.is_master,
-- 
GitLab