From cc85e2b833a931ec328406652a23ba3130655907 Mon Sep 17 00:00:00 2001
From: Yoann Schneider <yschneider@teklia.com>
Date: Thu, 6 Mar 2025 12:03:30 +0100
Subject: [PATCH] Specify sets to evaluate

---
 dan/__init__.py                  |  5 +++++
 dan/datasets/download/images.py  |  2 +-
 dan/datasets/extract/arkindex.py |  6 +-----
 dan/ocr/evaluate.py              | 23 +++++++++++++++++++----
 dan/ocr/manager/dataset.py       |  3 ++-
 dan/ocr/manager/ocr.py           |  5 +++--
 dan/ocr/manager/training.py      | 21 ++++++++++++++-------
 tests/conftest.py                |  2 +-
 tests/test_db.py                 |  2 +-
 9 files changed, 47 insertions(+), 22 deletions(-)

diff --git a/dan/__init__.py b/dan/__init__.py
index e6d4269a..ef68a7cb 100644
--- a/dan/__init__.py
+++ b/dan/__init__.py
@@ -8,3 +8,8 @@ logging.basicConfig(
     level=logging.INFO,
     format="%(asctime)s %(levelname)s/%(name)s: %(message)s",
 )
+
+TRAIN_NAME = "train"
+VAL_NAME = "val"
+TEST_NAME = "test"
+SPLIT_NAMES = [TRAIN_NAME, VAL_NAME, TEST_NAME]
diff --git a/dan/datasets/download/images.py b/dan/datasets/download/images.py
index 087c5178..90675324 100644
--- a/dan/datasets/download/images.py
+++ b/dan/datasets/download/images.py
@@ -18,12 +18,12 @@ import numpy as np
 from PIL import Image
 from tqdm import tqdm
 
+from dan import TRAIN_NAME
 from dan.datasets.download.exceptions import ImageDownloadError
 from dan.datasets.download.utils import (
     download_image,
     get_bbox,
 )
-from dan.datasets.extract.arkindex import TRAIN_NAME
 from line_image_extractor.extractor import extract
 from line_image_extractor.image_utils import (
     BoundingBox,
diff --git a/dan/datasets/extract/arkindex.py b/dan/datasets/extract/arkindex.py
index e06310ab..58998faf 100644
--- a/dan/datasets/extract/arkindex.py
+++ b/dan/datasets/extract/arkindex.py
@@ -14,6 +14,7 @@ from uuid import UUID
 from tqdm import tqdm
 
 from arkindex_export import Dataset, DatasetElement, Element, open_database
+from dan import SPLIT_NAMES, TEST_NAME, TRAIN_NAME, VAL_NAME
 from dan.datasets.extract.db import (
     get_dataset_elements,
     get_elements,
@@ -32,11 +33,6 @@ from dan.datasets.extract.utils import (
 )
 from dan.utils import parse_tokens
 
-TRAIN_NAME = "train"
-VAL_NAME = "val"
-TEST_NAME = "test"
-SPLIT_NAMES = [TRAIN_NAME, VAL_NAME, TEST_NAME]
-
 logger = logging.getLogger(__name__)
 
 
diff --git a/dan/ocr/evaluate.py b/dan/ocr/evaluate.py
index e13259b6..ce88f315 100644
--- a/dan/ocr/evaluate.py
+++ b/dan/ocr/evaluate.py
@@ -25,6 +25,7 @@ from nerval.utils import TABLE_HEADER as NERVAL_TABLE_HEADER
 from nerval.utils import print_results
 from prettytable import MARKDOWN, PrettyTable
 
+from dan import SPLIT_NAMES
 from dan.bio import convert
 from dan.ocr import wandb
 from dan.ocr.manager.metrics import Inference
@@ -83,6 +84,14 @@ def add_evaluate_parser(subcommands) -> None:
         type=Path,
     )
 
+    parser.add_argument(
+        "--sets",
+        dest="set_names",
        help="Names of the sets to evaluate. Defaults to all sets (train, val, test).",
+        default=SPLIT_NAMES,
+        nargs="+",
+    )
+
     parser.set_defaults(func=run)
 
 
@@ -196,6 +205,7 @@ def eval(
     nerval_threshold: float,
     output_json: Path | None,
     mlflow_logging: bool,
+    set_names: list[str],
 ):
     torch.manual_seed(0)
     torch.cuda.manual_seed(0)
@@ -227,7 +237,7 @@ def eval(
     metrics_values, all_inferences = [], {}
 
     for dataset_name in config["dataset"]["datasets"]:
-        for set_name in ["train", "val", "test"]:
+        for set_name in set_names:
             logger.info(f"Evaluating on set `{set_name}`")
             metrics, inferences = model.evaluate(
                 "{}-{}".format(dataset_name, set_name),
@@ -272,7 +282,12 @@ def eval(
             wandb.log_artifact(artifact, local_path=output_json, name=output_json.name)
 
 
-def run(config: dict, nerval_threshold: float, output_json: Path | None):
+def run(
+    config: dict,
+    nerval_threshold: float,
+    output_json: Path | None,
+    set_names: list[str] = SPLIT_NAMES,
+):
     update_config(config)
 
     # Start "Weights & Biases" as soon as possible
@@ -294,8 +309,8 @@ def run(config: dict, nerval_threshold: float, output_json: Path | None):
     ):
         mp.spawn(
             eval,
-            args=(config, nerval_threshold, output_json, mlflow_logging),
+            args=(config, nerval_threshold, output_json, mlflow_logging, set_names),
             nprocs=config["training"]["device"]["nb_gpu"],
         )
     else:
-        eval(0, config, nerval_threshold, output_json, mlflow_logging)
+        eval(0, config, nerval_threshold, output_json, mlflow_logging, set_names)
diff --git a/dan/ocr/manager/dataset.py b/dan/ocr/manager/dataset.py
index b9b2e6ed..5790207a 100644
--- a/dan/ocr/manager/dataset.py
+++ b/dan/ocr/manager/dataset.py
@@ -10,6 +10,7 @@ import numpy as np
 from torch.utils.data import Dataset
 from tqdm import tqdm
 
+from dan import TRAIN_NAME
 from dan.datasets.utils import natural_sort
 from dan.utils import check_valid_size, read_image, token_to_ind
 
@@ -171,7 +172,7 @@ class OCRDataset(Dataset):
         """
         image_reduced_shape = np.ceil(img.shape / self.reduce_dims_factor).astype(int)
 
-        if self.set_name == "train":
+        if self.set_name == TRAIN_NAME:
             image_reduced_shape = [max(1, t) for t in image_reduced_shape]
 
         image_position = [
diff --git a/dan/ocr/manager/ocr.py b/dan/ocr/manager/ocr.py
index 42224ca7..01beb1a9 100644
--- a/dan/ocr/manager/ocr.py
+++ b/dan/ocr/manager/ocr.py
@@ -10,6 +10,7 @@ import torch
 from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler
 
+from dan import TRAIN_NAME
 from dan.ocr.manager.dataset import OCRDataset
 from dan.ocr.transforms import get_augmentation_transforms, get_preprocessing_transforms
 from dan.utils import pad_images, pad_sequences_1D
@@ -66,8 +67,8 @@ class OCRDatasetManager:
         Load training and validation datasets
         """
         self.train_dataset = OCRDataset(
-            set_name="train",
-            paths_and_sets=self.get_paths_and_sets(self.params["train"]["datasets"]),
+            set_name=TRAIN_NAME,
+            paths_and_sets=self.get_paths_and_sets(self.params[TRAIN_NAME]["datasets"]),
             charset=self.charset,
             tokens=self.tokens,
             preprocessing_transforms=self.preprocessing,
diff --git a/dan/ocr/manager/training.py b/dan/ocr/manager/training.py
index 1dc893c5..3dc82e11 100644
--- a/dan/ocr/manager/training.py
+++ b/dan/ocr/manager/training.py
@@ -25,6 +25,7 @@ from torch.utils.tensorboard import SummaryWriter
 from torchvision.transforms import Compose, PILToTensor
 from tqdm import tqdm
 
+from dan import TRAIN_NAME
 from dan.ocr import wandb
 from dan.ocr.decoder import CTCLanguageDecoder, GlobalHTADecoder
 from dan.ocr.encoder import FCN_Encoder
@@ -645,7 +646,7 @@ class GenericTrainingManager:
 
         # init variables
         nb_epochs = self.params["training"]["max_nb_epochs"]
-        metric_names = self.params["training"]["metrics"]["train"]
+        metric_names = self.params["training"]["metrics"][TRAIN_NAME]
 
         display_values = None
         # perform epochs
@@ -669,7 +670,7 @@ class GenericTrainingManager:
                     self.latest_epoch
                 )
             # init epoch metrics values
-            self.metric_manager["train"] = MetricManager(
+            self.metric_manager[TRAIN_NAME] = MetricManager(
                 metric_names=metric_names,
                 dataset_name=self.dataset_name,
                 tokens=self.tokens,
@@ -684,7 +685,7 @@ class GenericTrainingManager:
 
                     # train on batch data and compute metrics
                     batch_values = self.train_batch(batch_data, metric_names)
-                    batch_metrics = self.metric_manager["train"].compute_metrics(
+                    batch_metrics = self.metric_manager[TRAIN_NAME].compute_metrics(
                         batch_values, metric_names
                     )
                     batch_metrics["names"] = batch_data["names"]
@@ -716,14 +717,20 @@ class GenericTrainingManager:
                     self.dropout_scheduler.update_dropout_rate()
 
                     # Add batch metrics values to epoch metrics values
-                    self.metric_manager["train"].update_metrics(batch_metrics)
-                    display_values = self.metric_manager["train"].get_display_values()
+                    self.metric_manager[TRAIN_NAME].update_metrics(batch_metrics)
+                    display_values = self.metric_manager[
+                        TRAIN_NAME
+                    ].get_display_values()
                     pbar.set_postfix(values=str(display_values))
                     pbar.update(len(batch_data["names"]) * self.nb_workers)
 
                 # Log MLflow metrics
                 logging_metrics(
-                    display_values, "train", num_epoch, mlflow_logging, self.is_master
+                    display_values,
+                    TRAIN_NAME,
+                    num_epoch,
+                    mlflow_logging,
+                    self.is_master,
                 )
 
             if self.is_master:
@@ -731,7 +738,7 @@ class GenericTrainingManager:
                 for key in display_values:
                     self.writer.add_scalar(
                         "train/{}_{}".format(
-                            self.params["dataset"]["train"]["name"], key
+                            self.params["dataset"][TRAIN_NAME]["name"], key
                         ),
                         display_values[key],
                         num_epoch,
diff --git a/tests/conftest.py b/tests/conftest.py
index 5afdc424..f66df1fb 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -24,7 +24,7 @@ from arkindex_export import (
     WorkerVersion,
     database,
 )
-from dan.datasets.extract.arkindex import TEST_NAME, TRAIN_NAME, VAL_NAME
+from dan import TEST_NAME, TRAIN_NAME, VAL_NAME
 from tests import FIXTURES
 
 
diff --git a/tests/test_db.py b/tests/test_db.py
index 448376ea..6b05415c 100644
--- a/tests/test_db.py
+++ b/tests/test_db.py
@@ -8,7 +8,7 @@ from operator import itemgetter
 import pytest
 
 from arkindex_export import Dataset, DatasetElement, Element
-from dan.datasets.extract.arkindex import TRAIN_NAME
+from dan import TRAIN_NAME
 from dan.datasets.extract.db import (
     get_dataset_elements,
     get_elements,
-- 
GitLab