diff --git a/dan/__init__.py b/dan/__init__.py
index e6d4269a556a156c0d53f907779501090d58dc04..ef68a7cbc608dcde7a7a9f84ea3ff62d10ca5485 100644
--- a/dan/__init__.py
+++ b/dan/__init__.py
@@ -8,3 +8,9 @@ logging.basicConfig(
     level=logging.INFO,
     format="%(asctime)s %(levelname)s/%(name)s: %(message)s",
 )
+
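+# Canonical dataset split names, shared across dataset extraction, download, training and evaluation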
+TRAIN_NAME = "train"
+VAL_NAME = "val"
+TEST_NAME = "test"
+SPLIT_NAMES = [TRAIN_NAME, VAL_NAME, TEST_NAME]
diff --git a/dan/datasets/download/images.py b/dan/datasets/download/images.py
index 087c517897925a7586aaca35467ef43f0f812ccf..906753242d49a23054bc025f9be0c88ee4523427 100644
--- a/dan/datasets/download/images.py
+++ b/dan/datasets/download/images.py
@@ -18,12 +18,12 @@ import numpy as np
 from PIL import Image
 from tqdm import tqdm
 
+from dan import TRAIN_NAME
 from dan.datasets.download.exceptions import ImageDownloadError
 from dan.datasets.download.utils import (
     download_image,
     get_bbox,
 )
-from dan.datasets.extract.arkindex import TRAIN_NAME
 from line_image_extractor.extractor import extract
 from line_image_extractor.image_utils import (
     BoundingBox,
diff --git a/dan/datasets/extract/arkindex.py b/dan/datasets/extract/arkindex.py
index e06310abb75fc4011d2bbc6e277f535403b56122..58998faf3735b233bb26302fdf85c16c80e42ff5 100644
--- a/dan/datasets/extract/arkindex.py
+++ b/dan/datasets/extract/arkindex.py
@@ -14,6 +14,7 @@ from uuid import UUID
 from tqdm import tqdm
 
 from arkindex_export import Dataset, DatasetElement, Element, open_database
+from dan import SPLIT_NAMES, TEST_NAME, TRAIN_NAME, VAL_NAME
 from dan.datasets.extract.db import (
     get_dataset_elements,
     get_elements,
@@ -32,11 +33,6 @@ from dan.datasets.extract.utils import (
 )
 from dan.utils import parse_tokens
 
-TRAIN_NAME = "train"
-VAL_NAME = "val"
-TEST_NAME = "test"
-SPLIT_NAMES = [TRAIN_NAME, VAL_NAME, TEST_NAME]
-
 logger = logging.getLogger(__name__)
 
 
diff --git a/dan/ocr/evaluate.py b/dan/ocr/evaluate.py
index e13259b687cf0ab3eaff25557785e4ba5cc6e79e..ce88f3153febfefcc99ef55a69b8953a0b4f3fe2 100644
--- a/dan/ocr/evaluate.py
+++ b/dan/ocr/evaluate.py
@@ -25,6 +25,7 @@ from nerval.utils import TABLE_HEADER as NERVAL_TABLE_HEADER
 from nerval.utils import print_results
 from prettytable import MARKDOWN, PrettyTable
 
+from dan import SPLIT_NAMES
 from dan.bio import convert
 from dan.ocr import wandb
 from dan.ocr.manager.metrics import Inference
@@ -83,6 +84,14 @@ def add_evaluate_parser(subcommands) -> None:
         type=Path,
     )
 
+    parser.add_argument(
+        "--sets",
+        dest="set_names",
+        help="Where to save evaluation results in JSON format.",
+        default=SPLIT_NAMES,
+        nargs="+",
+    )
+
     parser.set_defaults(func=run)
 
 
@@ -196,6 +205,7 @@ def eval(
     nerval_threshold: float,
     output_json: Path | None,
     mlflow_logging: bool,
+    set_names: list[str],
 ):
     torch.manual_seed(0)
     torch.cuda.manual_seed(0)
@@ -227,7 +237,7 @@ def eval(
     metrics_values, all_inferences = [], {}
 
     for dataset_name in config["dataset"]["datasets"]:
-        for set_name in ["train", "val", "test"]:
+        for set_name in set_names:
             logger.info(f"Evaluating on set `{set_name}`")
             metrics, inferences = model.evaluate(
                 "{}-{}".format(dataset_name, set_name),
@@ -272,7 +282,12 @@ def eval(
             wandb.log_artifact(artifact, local_path=output_json, name=output_json.name)
 
 
-def run(config: dict, nerval_threshold: float, output_json: Path | None):
+def run(
+    config: dict,
+    nerval_threshold: float,
+    output_json: Path | None,
+    set_names: list[str] = SPLIT_NAMES,
+):
     update_config(config)
 
     # Start "Weights & Biases" as soon as possible
@@ -294,8 +309,8 @@ def run(config: dict, nerval_threshold: float, output_json: Path | None):
     ):
         mp.spawn(
             eval,
-            args=(config, nerval_threshold, output_json, mlflow_logging),
+            args=(config, nerval_threshold, output_json, mlflow_logging, set_names),
             nprocs=config["training"]["device"]["nb_gpu"],
         )
     else:
-        eval(0, config, nerval_threshold, output_json, mlflow_logging)
+        eval(0, config, nerval_threshold, output_json, mlflow_logging, set_names)
diff --git a/dan/ocr/manager/dataset.py b/dan/ocr/manager/dataset.py
index b9b2e6edfd93fbea2d866938b29bb46a5f893c8e..5790207a4e5560823207a27445ff7b2d732f4369 100644
--- a/dan/ocr/manager/dataset.py
+++ b/dan/ocr/manager/dataset.py
@@ -10,6 +10,7 @@ import numpy as np
 from torch.utils.data import Dataset
 from tqdm import tqdm
 
+from dan import TRAIN_NAME
 from dan.datasets.utils import natural_sort
 from dan.utils import check_valid_size, read_image, token_to_ind
 
@@ -171,7 +172,7 @@ class OCRDataset(Dataset):
         """
         image_reduced_shape = np.ceil(img.shape / self.reduce_dims_factor).astype(int)
 
-        if self.set_name == "train":
+        if self.set_name == TRAIN_NAME:
             image_reduced_shape = [max(1, t) for t in image_reduced_shape]
 
         image_position = [
diff --git a/dan/ocr/manager/ocr.py b/dan/ocr/manager/ocr.py
index 42224ca7c65f4fae7620e5edb637ef3fb4252983..01beb1a9bc4ccdb84bcd8fdbf724ee67cc522a21 100644
--- a/dan/ocr/manager/ocr.py
+++ b/dan/ocr/manager/ocr.py
@@ -10,6 +10,7 @@ import torch
 from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler
 
+from dan import TRAIN_NAME
 from dan.ocr.manager.dataset import OCRDataset
 from dan.ocr.transforms import get_augmentation_transforms, get_preprocessing_transforms
 from dan.utils import pad_images, pad_sequences_1D
@@ -66,8 +67,8 @@ class OCRDatasetManager:
         Load training and validation datasets
         """
         self.train_dataset = OCRDataset(
-            set_name="train",
-            paths_and_sets=self.get_paths_and_sets(self.params["train"]["datasets"]),
+            set_name=TRAIN_NAME,
+            paths_and_sets=self.get_paths_and_sets(self.params[TRAIN_NAME]["datasets"]),
             charset=self.charset,
             tokens=self.tokens,
             preprocessing_transforms=self.preprocessing,
diff --git a/dan/ocr/manager/training.py b/dan/ocr/manager/training.py
index 1dc893c57c28b6c002a08bf0c0bb4b9cc0e728fa..3dc82e11d2d85b6b8b6e8a14c3a75a55970cde9d 100644
--- a/dan/ocr/manager/training.py
+++ b/dan/ocr/manager/training.py
@@ -25,6 +25,7 @@ from torch.utils.tensorboard import SummaryWriter
 from torchvision.transforms import Compose, PILToTensor
 from tqdm import tqdm
 
+from dan import TRAIN_NAME
 from dan.ocr import wandb
 from dan.ocr.decoder import CTCLanguageDecoder, GlobalHTADecoder
 from dan.ocr.encoder import FCN_Encoder
@@ -645,7 +646,7 @@ class GenericTrainingManager:
 
         # init variables
         nb_epochs = self.params["training"]["max_nb_epochs"]
-        metric_names = self.params["training"]["metrics"]["train"]
+        metric_names = self.params["training"]["metrics"][TRAIN_NAME]
 
         display_values = None
         # perform epochs
@@ -669,7 +670,7 @@ class GenericTrainingManager:
                     self.latest_epoch
                 )
             # init epoch metrics values
-            self.metric_manager["train"] = MetricManager(
+            self.metric_manager[TRAIN_NAME] = MetricManager(
                 metric_names=metric_names,
                 dataset_name=self.dataset_name,
                 tokens=self.tokens,
@@ -684,7 +685,7 @@ class GenericTrainingManager:
 
                     # train on batch data and compute metrics
                     batch_values = self.train_batch(batch_data, metric_names)
-                    batch_metrics = self.metric_manager["train"].compute_metrics(
+                    batch_metrics = self.metric_manager[TRAIN_NAME].compute_metrics(
                         batch_values, metric_names
                     )
                     batch_metrics["names"] = batch_data["names"]
@@ -716,14 +717,20 @@ class GenericTrainingManager:
                     self.dropout_scheduler.update_dropout_rate()
 
                     # Add batch metrics values to epoch metrics values
-                    self.metric_manager["train"].update_metrics(batch_metrics)
-                    display_values = self.metric_manager["train"].get_display_values()
+                    self.metric_manager[TRAIN_NAME].update_metrics(batch_metrics)
+                    display_values = self.metric_manager[
+                        TRAIN_NAME
+                    ].get_display_values()
                     pbar.set_postfix(values=str(display_values))
                     pbar.update(len(batch_data["names"]) * self.nb_workers)
 
                 # Log MLflow metrics
                 logging_metrics(
-                    display_values, "train", num_epoch, mlflow_logging, self.is_master
+                    display_values,
+                    TRAIN_NAME,
+                    num_epoch,
+                    mlflow_logging,
+                    self.is_master,
                 )
 
             if self.is_master:
@@ -731,7 +738,7 @@ class GenericTrainingManager:
                 for key in display_values:
                     self.writer.add_scalar(
                         "train/{}_{}".format(
-                            self.params["dataset"]["train"]["name"], key
+                            self.params["dataset"][TRAIN_NAME]["name"], key
                         ),
                         display_values[key],
                         num_epoch,
diff --git a/tests/conftest.py b/tests/conftest.py
index 5afdc424fa2827f86b53952b51e08b25b8182495..f66df1fb895a83c4ccad3353c7a7bc644c7e38b0 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -24,7 +24,7 @@ from arkindex_export import (
     WorkerVersion,
     database,
 )
-from dan.datasets.extract.arkindex import TEST_NAME, TRAIN_NAME, VAL_NAME
+from dan import TEST_NAME, TRAIN_NAME, VAL_NAME
 from tests import FIXTURES
 
 
diff --git a/tests/test_db.py b/tests/test_db.py
index 448376ea42176b4e34e75bc84007ea7a072d1954..6b05415c46ea037b4d23d2dc1e1311e7a0f5d7fb 100644
--- a/tests/test_db.py
+++ b/tests/test_db.py
@@ -8,7 +8,7 @@ from operator import itemgetter
 import pytest
 
 from arkindex_export import Dataset, DatasetElement, Element
-from dan.datasets.extract.arkindex import TRAIN_NAME
+from dan import TRAIN_NAME
 from dan.datasets.extract.db import (
     get_dataset_elements,
     get_elements,