diff --git a/dan/ocr/manager/ocr.py b/dan/ocr/manager/ocr.py index 01beb1a9bc4ccdb84bcd8fdbf724ee67cc522a21..fee3519bf97346d71d3b1b85efb5d5f0d909f413 100644 --- a/dan/ocr/manager/ocr.py +++ b/dan/ocr/manager/ocr.py @@ -10,7 +10,7 @@ import torch from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler -from dan import TRAIN_NAME +from dan import TEST_NAME, TRAIN_NAME, VAL_NAME from dan.ocr.manager.dataset import OCRDataset from dan.ocr.transforms import get_augmentation_transforms, get_preprocessing_transforms from dan.utils import pad_images, pad_sequences_1D @@ -80,10 +80,12 @@ class OCRDatasetManager: self.mean, self.std = self.train_dataset.compute_std_mean() - for custom_name in self.params["val"]: + for custom_name in self.params[VAL_NAME]: self.valid_datasets[custom_name] = OCRDataset( - set_name="val", - paths_and_sets=self.get_paths_and_sets(self.params["val"][custom_name]), + set_name=VAL_NAME, + paths_and_sets=self.get_paths_and_sets( + self.params[VAL_NAME][custom_name] + ), charset=self.charset, tokens=self.tokens, preprocessing_transforms=self.preprocessing, @@ -169,7 +171,7 @@ class OCRDatasetManager: {"path": self.params["datasets"][set_info[0]], "set_name": set_info[1]} ) self.test_datasets[custom_name] = OCRDataset( - set_name="test", + set_name=TEST_NAME, paths_and_sets=paths_and_sets, charset=self.charset, tokens=self.tokens, diff --git a/dan/ocr/manager/training.py b/dan/ocr/manager/training.py index 3dc82e11d2d85b6b8b6e8a14c3a75a55970cde9d..d8cc47fa1af0a3ba487b25bf354a1a6b9e7a0d4f 100644 --- a/dan/ocr/manager/training.py +++ b/dan/ocr/manager/training.py @@ -25,7 +25,7 @@ from torch.utils.tensorboard import SummaryWriter from torchvision.transforms import Compose, PILToTensor from tqdm import tqdm -from dan import TRAIN_NAME +from dan import TRAIN_NAME, VAL_NAME from dan.ocr import wandb from dan.ocr.decoder import CTCLanguageDecoder, GlobalHTADecoder from dan.ocr.encoder import FCN_Encoder @@ -873,7 +873,7 @@ class GenericTrainingManager: # log metrics in MLflow logging_metrics( display_values, - "val", + VAL_NAME, self.latest_epoch, mlflow_logging, self.is_master,