diff --git a/dan/ocr/predict/__init__.py b/dan/ocr/predict/__init__.py
index 1c65254a96a7a661d0b73a2a92ee55260e5aff82..0ee92b8b38fe4fec9915d1a38b967cdc560c3ab2 100644
--- a/dan/ocr/predict/__init__.py
+++ b/dan/ocr/predict/__init__.py
@@ -6,6 +6,7 @@ Predict on an image using a trained DAN model.
 import pathlib
 
 from dan.ocr.predict.prediction import run
+from dan.utils import parse_tokens
 
 
 def add_predict_parser(subcommands) -> None:
@@ -51,13 +52,13 @@ def add_predict_parser(subcommands) -> None:
         help="Path to the output folder.",
         required=True,
     )
+    # Optional arguments.
     parser.add_argument(
         "--tokens",
-        type=pathlib.Path,
-        required=True,
+        type=parse_tokens,
+        required=False,
         help="Path to a yaml file containing a mapping between starting tokens and end tokens. Needed for entities.",
     )
-    # Optional arguments.
     parser.add_argument(
         "--image-extension",
         type=str,
@@ -154,5 +155,10 @@ def add_predict_parser(subcommands) -> None:
         default=1,
         required=False,
     )
-
+    parser.add_argument(
+        "--start-token",
+        help="Use a specific starting token at the beginning of the prediction. Useful when making predictions on different single pages.",
+        type=str,
+        required=False,
+    )
     parser.set_defaults(func=run)
diff --git a/dan/ocr/predict/prediction.py b/dan/ocr/predict/prediction.py
index 32758f8a29553bffc717c93c896eeb2ad6996010..cdca2a90b6fafa499bac974e3c8c722cad78077e 100644
--- a/dan/ocr/predict/prediction.py
+++ b/dan/ocr/predict/prediction.py
@@ -5,7 +5,7 @@ import pickle
 import re
 from itertools import pairwise
 from pathlib import Path
-from typing import Iterable, List, Optional, Tuple
+from typing import Dict, Iterable, List, Optional, Tuple
 
 import numpy as np
 import torch
@@ -22,10 +22,10 @@ from dan.ocr.predict.attention import (
 )
 from dan.ocr.transforms import get_preprocessing_transforms
 from dan.utils import (
+    EntityType,
     ind_to_token,
     list_to_batches,
     pad_images,
-    parse_tokens,
     read_image,
 )
 
@@ -287,7 +287,8 @@ def process_batch(
     predict_objects: bool,
     threshold_method: str,
     threshold_value: int,
-    tokens: Path,
+    tokens: Dict[str, EntityType],
+    start_token: str,
 ) -> None:
     input_images, visu_images, input_sizes = [], [], []
     logger.info("Loading images...")
@@ -320,6 +321,7 @@ def process_batch(
         line_separators=line_separators,
         threshold_method=threshold_method,
         threshold_value=threshold_value,
+        start_token=start_token,
     )
 
     logger.info("Prediction parsing...")
@@ -336,7 +338,6 @@ def process_batch(
             result["confidences"] = {}
             char_confidences = prediction["confidences"][0]
             text = result["text"]
-            tokens = parse_tokens(tokens)
             start_tokens, end_tokens = zip(*list(tokens.values()))
             end_tokens = list(filter(bool, end_tokens))
 
@@ -422,7 +423,8 @@ def run(
     image_extension: str,
     gpu_device: int,
     batch_size: int,
-    tokens: Path,
+    tokens: Dict[str, EntityType],
+    start_token: str,
 ) -> None:
     """
     Predict a single image save the output
@@ -443,6 +445,8 @@ def run(
     :param threshold_value: Thresholding value to use for the "simple" thresholding method.
     :param gpu_device: Use a specific GPU if available.
     :param batch_size: Size of the batches for prediction.
+    :param tokens: NER tokens used.
+    :param start_token: Use a specific starting token at the beginning of the prediction. Useful when making predictions on different single pages.
     """
     # Create output directory if necessary
     if not output.exists():
@@ -472,4 +476,5 @@ def run(
             threshold_method,
             threshold_value,
             tokens,
+            start_token,
         )
diff --git a/dan/utils.py b/dan/utils.py
index 92954f8a59d48f9d5e7384c10c713912949ffce9..5cb62ff5e39e6eb776eb68d4497f53a25d539144 100644
--- a/dan/utils.py
+++ b/dan/utils.py
@@ -1,7 +1,7 @@
 # -*- coding: utf-8 -*-
 from itertools import islice
 from pathlib import Path
-from typing import NamedTuple
+from typing import Dict, NamedTuple
 
 import torch
 import torchvision.io as torchvision
@@ -106,7 +106,16 @@ def list_to_batches(iterable, n):
         yield batch
 
 
-def parse_tokens(filename: Path) -> dict:
+def parse_tokens(filename: Path) -> Dict[str, EntityType]:
+    """
+    Load the entity tokens mapping from a YAML file.
+
+    :param filename: Path to the YAML file. A plain string is accepted too,
+        since argparse hands the raw command-line value to `type=` callables.
+    :return: Mapping from entity name to its `EntityType`.
+    """
+    # Coerce to Path: a `str` has no `read_text`, and the resulting
+    # AttributeError is not turned into a usage error by argparse.
     return {
         name: EntityType(**tokens)
-        for name, tokens in yaml.safe_load(filename.read_text()).items()
+        for name, tokens in yaml.safe_load(Path(filename).read_text()).items()
diff --git a/docs/usage/predict.md b/docs/usage/predict.md
index 294e67aafd54b7a5750d4d99d7df0a6b4cf1a143..d6d46350d3db18aaa35b9ff732d077d48dd767f3 100644
--- a/docs/usage/predict.md
+++ b/docs/usage/predict.md
@@ -4,26 +4,27 @@ Use the `teklia-dan predict` command to apply a trained DAN model on an image.
 
 ## Description of parameters
 
-| Parameter                   | Description                                                                                     | Type    | Default       |
-| --------------------------- | ----------------------------------------------------------------------------------------------- | ------- | ------------- |
-| `--image`                   | Path to the image to predict. Must not be provided with `--image-dir`.                          | `Path`  |               |
-| `--image-dir`               | Path to the folder where the images to predict are stored. Must not be provided with `--image`. | `Path`  |               |
-| `--image-extension`         | The extension of the images in the folder. Ignored if `--image-dir` is not provided.            | `str`   | .jpg          |
-| `--model`                   | Path to the model to use for prediction                                                         | `Path`  |               |
-| `--parameters`              | Path to the YAML parameters file.                                                               | `Path`  |               |
-| `--charset`                 | Path to the charset file.                                                                       | `Path`  |               |
-| `--output`                  | Path to the output folder. Results will be saved in this directory.                             | `Path`  |               |
-| `--confidence-score`        | Whether to return confidence scores.                                                            | `bool`  | `False`       |
-| `--confidence-score-levels` | Level to return confidence scores. Should be any combination of `["line", "word", "char"]`.     | `str`   |               |
-| `--attention-map`           | Whether to plot attention maps.                                                                 | `bool`  | `False`       |
-| `--attention-map-scale`     | Image scaling factor before creating the GIF.                                                   | `float` | `0.5`         |
-| `--attention-map-level`     | Level to plot the attention maps. Should be in `["line", "word", "char"]`.                      | `str`   | `"line"`      |
-| `--predict-objects`         | Whether to return polygons coordinates.                                                         | `bool`  | `False`       |
-| `--word-separators`         | List of word separators.                                                                        | `list`  | `[" ", "\n"]` |
-| `--line-separators`         | List of line separators.                                                                        | `list`  | `["\n"]`      |
-| `--threshold-method`        | Method to use for attention mask thresholding. Should be in `["otsu", "simple"]`.               | `str`   | `"otsu"`      |
-| `--threshold-value `        | Threshold to use for the "simple" thresholding method.                                          | `int`   | `0`           |
-| `--batch-size `             | Size of the batches for prediction.                                                             | `int`   | `1`           |
+| Parameter                   | Description                                                                                                                 | Type    | Default       |
+| --------------------------- | --------------------------------------------------------------------------------------------------------------------------- | ------- | ------------- |
+| `--image`                   | Path to the image to predict. Must not be provided with `--image-dir`.                                                      | `Path`  |               |
+| `--image-dir`               | Path to the folder where the images to predict are stored. Must not be provided with `--image`.                             | `Path`  |               |
+| `--image-extension`         | The extension of the images in the folder. Ignored if `--image-dir` is not provided.                                        | `str`   | .jpg          |
+| `--model`                   | Path to the model to use for prediction.                                                                                    | `Path`  |               |
+| `--parameters`              | Path to the YAML parameters file.                                                                                           | `Path`  |               |
+| `--charset`                 | Path to the charset file.                                                                                                   | `Path`  |               |
+| `--output`                  | Path to the output folder. Results will be saved in this directory.                                                         | `Path`  |               |
+| `--confidence-score`        | Whether to return confidence scores.                                                                                        | `bool`  | `False`       |
+| `--confidence-score-levels` | Level to return confidence scores. Should be any combination of `["line", "word", "char"]`.                                 | `str`   |               |
+| `--attention-map`           | Whether to plot attention maps.                                                                                             | `bool`  | `False`       |
+| `--attention-map-scale`     | Image scaling factor before creating the GIF.                                                                               | `float` | `0.5`         |
+| `--attention-map-level`     | Level to plot the attention maps. Should be in `["line", "word", "char"]`.                                                  | `str`   | `"line"`      |
+| `--predict-objects`         | Whether to return polygons coordinates.                                                                                     | `bool`  | `False`       |
+| `--word-separators`         | List of word separators.                                                                                                    | `list`  | `[" ", "\n"]` |
+| `--line-separators`         | List of line separators.                                                                                                    | `list`  | `["\n"]`      |
+| `--threshold-method`        | Method to use for attention mask thresholding. Should be in `["otsu", "simple"]`.                                           | `str`   | `"otsu"`      |
+| `--threshold-value`         | Threshold to use for the "simple" thresholding method.                                                                      | `int`   | `0`           |
+| `--batch-size`              | Size of the batches for prediction.                                                                                         | `int`   | `1`           |
+| `--start-token`             | Use a specific starting token at the beginning of the prediction. Useful when making predictions on different single pages. | `str`   | `None`        |
 
 ## Examples
 
diff --git a/tests/test_prediction.py b/tests/test_prediction.py
index b9e6e8779578a11ee46eb6349d822102441bc4b2..3acbb433db836ead665cd8d16edbc30eb5aaf623 100644
--- a/tests/test_prediction.py
+++ b/tests/test_prediction.py
@@ -7,6 +7,7 @@ import pytest
 
 from dan.ocr.predict.prediction import DAN
 from dan.ocr.predict.prediction import run as run_prediction
+from dan.utils import parse_tokens
 
 
 @pytest.mark.parametrize(
@@ -316,7 +317,8 @@ def test_run_prediction(
         image_extension=None,
         gpu_device=None,
         batch_size=1,
-        tokens=prediction_data_path / "tokens.yml",
+        tokens=parse_tokens(prediction_data_path / "tokens.yml"),
+        start_token=None,
     )
 
     prediction = json.loads((tmp_path / image_name).with_suffix(".json").read_text())
@@ -514,7 +516,8 @@ def test_run_prediction_batch(
         image_extension=".png",
         gpu_device=None,
         batch_size=batch_size,
-        tokens=prediction_data_path / "tokens.yml",
+        tokens=parse_tokens(prediction_data_path / "tokens.yml"),
+        start_token=None,
     )
 
     for image_name, expected_prediction in zip(image_names, expected_predictions):