From b1258a41539989bcfb8d4cdeb87645fb37ca99a6 Mon Sep 17 00:00:00 2001 From: Yoann Schneider <yschneider@teklia.com> Date: Mon, 10 Mar 2025 11:39:48 +0100 Subject: [PATCH] Use these variables in managers --- dan/ocr/manager/ocr.py | 12 +++++++----- dan/ocr/manager/training.py | 4 ++-- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/dan/ocr/manager/ocr.py b/dan/ocr/manager/ocr.py index 01beb1a9..fee3519b 100644 --- a/dan/ocr/manager/ocr.py +++ b/dan/ocr/manager/ocr.py @@ -10,7 +10,7 @@ import torch from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler -from dan import TRAIN_NAME +from dan import TEST_NAME, TRAIN_NAME, VAL_NAME from dan.ocr.manager.dataset import OCRDataset from dan.ocr.transforms import get_augmentation_transforms, get_preprocessing_transforms from dan.utils import pad_images, pad_sequences_1D @@ -80,10 +80,12 @@ class OCRDatasetManager: self.mean, self.std = self.train_dataset.compute_std_mean() - for custom_name in self.params["val"]: + for custom_name in self.params[VAL_NAME]: self.valid_datasets[custom_name] = OCRDataset( - set_name="val", - paths_and_sets=self.get_paths_and_sets(self.params["val"][custom_name]), + set_name=VAL_NAME, + paths_and_sets=self.get_paths_and_sets( + self.params[VAL_NAME][custom_name] + ), charset=self.charset, tokens=self.tokens, preprocessing_transforms=self.preprocessing, @@ -169,7 +171,7 @@ class OCRDatasetManager: {"path": self.params["datasets"][set_info[0]], "set_name": set_info[1]} ) self.test_datasets[custom_name] = OCRDataset( - set_name="test", + set_name=TEST_NAME, paths_and_sets=paths_and_sets, charset=self.charset, tokens=self.tokens, diff --git a/dan/ocr/manager/training.py b/dan/ocr/manager/training.py index 3dc82e11..d8cc47fa 100644 --- a/dan/ocr/manager/training.py +++ b/dan/ocr/manager/training.py @@ -25,7 +25,7 @@ from torch.utils.tensorboard import SummaryWriter from torchvision.transforms import Compose, PILToTensor from tqdm import tqdm -from dan import TRAIN_NAME +from dan import TRAIN_NAME, VAL_NAME from dan.ocr import wandb from dan.ocr.decoder import CTCLanguageDecoder, GlobalHTADecoder from dan.ocr.encoder import FCN_Encoder @@ -873,7 +873,7 @@ class GenericTrainingManager: # log metrics in MLflow logging_metrics( display_values, - "val", + VAL_NAME, self.latest_epoch, mlflow_logging, self.is_master, -- GitLab