diff --git a/dan/__init__.py b/dan/__init__.py index e6d4269a556a156c0d53f907779501090d58dc04..ef68a7cbc608dcde7a7a9f84ea3ff62d10ca5485 100644 --- a/dan/__init__.py +++ b/dan/__init__.py @@ -8,3 +8,8 @@ logging.basicConfig( level=logging.INFO, format="%(asctime)s %(levelname)s/%(name)s: %(message)s", ) + +TRAIN_NAME = "train" +VAL_NAME = "val" +TEST_NAME = "test" +SPLIT_NAMES = [TRAIN_NAME, VAL_NAME, TEST_NAME] diff --git a/dan/datasets/download/images.py b/dan/datasets/download/images.py index 087c517897925a7586aaca35467ef43f0f812ccf..906753242d49a23054bc025f9be0c88ee4523427 100644 --- a/dan/datasets/download/images.py +++ b/dan/datasets/download/images.py @@ -18,12 +18,12 @@ import numpy as np from PIL import Image from tqdm import tqdm +from dan import TRAIN_NAME from dan.datasets.download.exceptions import ImageDownloadError from dan.datasets.download.utils import ( download_image, get_bbox, ) -from dan.datasets.extract.arkindex import TRAIN_NAME from line_image_extractor.extractor import extract from line_image_extractor.image_utils import ( BoundingBox, diff --git a/dan/datasets/extract/arkindex.py b/dan/datasets/extract/arkindex.py index e06310abb75fc4011d2bbc6e277f535403b56122..58998faf3735b233bb26302fdf85c16c80e42ff5 100644 --- a/dan/datasets/extract/arkindex.py +++ b/dan/datasets/extract/arkindex.py @@ -14,6 +14,7 @@ from uuid import UUID from tqdm import tqdm from arkindex_export import Dataset, DatasetElement, Element, open_database +from dan import SPLIT_NAMES, TEST_NAME, TRAIN_NAME, VAL_NAME from dan.datasets.extract.db import ( get_dataset_elements, get_elements, @@ -32,11 +33,6 @@ from dan.datasets.extract.utils import ( ) from dan.utils import parse_tokens -TRAIN_NAME = "train" -VAL_NAME = "val" -TEST_NAME = "test" -SPLIT_NAMES = [TRAIN_NAME, VAL_NAME, TEST_NAME] - logger = logging.getLogger(__name__) diff --git a/dan/ocr/evaluate.py b/dan/ocr/evaluate.py index e13259b687cf0ab3eaff25557785e4ba5cc6e79e..ce88f3153febfefcc99ef55a69b8953a0b4f3fe2 100644 --- a/dan/ocr/evaluate.py +++ b/dan/ocr/evaluate.py @@ -25,6 +25,7 @@ from nerval.utils import TABLE_HEADER as NERVAL_TABLE_HEADER from nerval.utils import print_results from prettytable import MARKDOWN, PrettyTable +from dan import SPLIT_NAMES from dan.bio import convert from dan.ocr import wandb from dan.ocr.manager.metrics import Inference @@ -83,6 +84,14 @@ def add_evaluate_parser(subcommands) -> None: type=Path, ) + parser.add_argument( + "--sets", + dest="set_names", + help="Where to save evaluation results in JSON format.", + default=SPLIT_NAMES, + nargs="+", + ) + parser.set_defaults(func=run) @@ -196,6 +205,7 @@ def eval( nerval_threshold: float, output_json: Path | None, mlflow_logging: bool, + set_names: list[str], ): torch.manual_seed(0) torch.cuda.manual_seed(0) @@ -227,7 +237,7 @@ def eval( metrics_values, all_inferences = [], {} for dataset_name in config["dataset"]["datasets"]: - for set_name in ["train", "val", "test"]: + for set_name in set_names: logger.info(f"Evaluating on set `{set_name}`") metrics, inferences = model.evaluate( "{}-{}".format(dataset_name, set_name), @@ -272,7 +282,12 @@ def eval( wandb.log_artifact(artifact, local_path=output_json, name=output_json.name) -def run(config: dict, nerval_threshold: float, output_json: Path | None): +def run( + config: dict, + nerval_threshold: float, + output_json: Path | None, + set_names: list[str] = SPLIT_NAMES, +): update_config(config) # Start "Weights & Biases" as soon as possible @@ -294,8 +309,8 @@ def run(config: dict, nerval_threshold: float, output_json: Path | None): ): mp.spawn( eval, - args=(config, nerval_threshold, output_json, mlflow_logging), + args=(config, nerval_threshold, output_json, mlflow_logging, set_names), nprocs=config["training"]["device"]["nb_gpu"], ) else: - eval(0, config, nerval_threshold, output_json, mlflow_logging) + eval(0, config, nerval_threshold, output_json, mlflow_logging, set_names) diff --git a/dan/ocr/manager/dataset.py b/dan/ocr/manager/dataset.py index b9b2e6edfd93fbea2d866938b29bb46a5f893c8e..5790207a4e5560823207a27445ff7b2d732f4369 100644 --- a/dan/ocr/manager/dataset.py +++ b/dan/ocr/manager/dataset.py @@ -10,6 +10,7 @@ import numpy as np from torch.utils.data import Dataset from tqdm import tqdm +from dan import TRAIN_NAME from dan.datasets.utils import natural_sort from dan.utils import check_valid_size, read_image, token_to_ind @@ -171,7 +172,7 @@ class OCRDataset(Dataset): """ image_reduced_shape = np.ceil(img.shape / self.reduce_dims_factor).astype(int) - if self.set_name == "train": + if self.set_name == TRAIN_NAME: image_reduced_shape = [max(1, t) for t in image_reduced_shape] image_position = [ diff --git a/dan/ocr/manager/ocr.py b/dan/ocr/manager/ocr.py index 42224ca7c65f4fae7620e5edb637ef3fb4252983..01beb1a9bc4ccdb84bcd8fdbf724ee67cc522a21 100644 --- a/dan/ocr/manager/ocr.py +++ b/dan/ocr/manager/ocr.py @@ -10,6 +10,7 @@ import torch from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler +from dan import TRAIN_NAME from dan.ocr.manager.dataset import OCRDataset from dan.ocr.transforms import get_augmentation_transforms, get_preprocessing_transforms from dan.utils import pad_images, pad_sequences_1D @@ -66,8 +67,8 @@ class OCRDatasetManager: Load training and validation datasets """ self.train_dataset = OCRDataset( - set_name="train", - paths_and_sets=self.get_paths_and_sets(self.params["train"]["datasets"]), + set_name=TRAIN_NAME, + paths_and_sets=self.get_paths_and_sets(self.params[TRAIN_NAME]["datasets"]), charset=self.charset, tokens=self.tokens, preprocessing_transforms=self.preprocessing, diff --git a/dan/ocr/manager/training.py b/dan/ocr/manager/training.py index 1dc893c57c28b6c002a08bf0c0bb4b9cc0e728fa..3dc82e11d2d85b6b8b6e8a14c3a75a55970cde9d 100644 --- a/dan/ocr/manager/training.py +++ b/dan/ocr/manager/training.py @@ -25,6 +25,7 @@ from torch.utils.tensorboard import SummaryWriter from torchvision.transforms import Compose, PILToTensor from tqdm import tqdm +from dan import TRAIN_NAME from dan.ocr import wandb from dan.ocr.decoder import CTCLanguageDecoder, GlobalHTADecoder from dan.ocr.encoder import FCN_Encoder @@ -645,7 +646,7 @@ class GenericTrainingManager: # init variables nb_epochs = self.params["training"]["max_nb_epochs"] - metric_names = self.params["training"]["metrics"]["train"] + metric_names = self.params["training"]["metrics"][TRAIN_NAME] display_values = None # perform epochs @@ -669,7 +670,7 @@ class GenericTrainingManager: self.latest_epoch ) # init epoch metrics values - self.metric_manager["train"] = MetricManager( + self.metric_manager[TRAIN_NAME] = MetricManager( metric_names=metric_names, dataset_name=self.dataset_name, tokens=self.tokens, @@ -684,7 +685,7 @@ class GenericTrainingManager: # train on batch data and compute metrics batch_values = self.train_batch(batch_data, metric_names) - batch_metrics = self.metric_manager["train"].compute_metrics( + batch_metrics = self.metric_manager[TRAIN_NAME].compute_metrics( batch_values, metric_names ) batch_metrics["names"] = batch_data["names"] @@ -716,14 +717,20 @@ class GenericTrainingManager: self.dropout_scheduler.update_dropout_rate() # Add batch metrics values to epoch metrics values - self.metric_manager["train"].update_metrics(batch_metrics) - display_values = self.metric_manager["train"].get_display_values() + self.metric_manager[TRAIN_NAME].update_metrics(batch_metrics) + display_values = self.metric_manager[ + TRAIN_NAME + ].get_display_values() pbar.set_postfix(values=str(display_values)) pbar.update(len(batch_data["names"]) * self.nb_workers) # Log MLflow metrics logging_metrics( - display_values, "train", num_epoch, mlflow_logging, self.is_master + display_values, + TRAIN_NAME, + num_epoch, + mlflow_logging, + self.is_master, ) if self.is_master: @@ -731,7 +738,7 @@ class GenericTrainingManager: for key in display_values: self.writer.add_scalar( "train/{}_{}".format( - self.params["dataset"]["train"]["name"], key + self.params["dataset"][TRAIN_NAME]["name"], key ), display_values[key], num_epoch, diff --git a/tests/conftest.py b/tests/conftest.py index 5afdc424fa2827f86b53952b51e08b25b8182495..f66df1fb895a83c4ccad3353c7a7bc644c7e38b0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -24,7 +24,7 @@ from arkindex_export import ( WorkerVersion, database, ) -from dan.datasets.extract.arkindex import TEST_NAME, TRAIN_NAME, VAL_NAME +from dan import TEST_NAME, TRAIN_NAME, VAL_NAME from tests import FIXTURES diff --git a/tests/test_db.py b/tests/test_db.py index 448376ea42176b4e34e75bc84007ea7a072d1954..6b05415c46ea037b4d23d2dc1e1311e7a0f5d7fb 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -8,7 +8,7 @@ from operator import itemgetter import pytest from arkindex_export import Dataset, DatasetElement, Element -from dan.datasets.extract.arkindex import TRAIN_NAME +from dan import TRAIN_NAME from dan.datasets.extract.db import ( get_dataset_elements, get_elements,