From c5fa54a35efbda482390a2235348bdcdbe8385a4 Mon Sep 17 00:00:00 2001
From: Yoann Schneider <yschneider@teklia.com>
Date: Mon, 21 Aug 2023 13:38:49 +0000
Subject: [PATCH] Parse tokens during argparse parsing and expose start_token

---
 dan/ocr/predict/__init__.py   | 14 ++++++++----
 dan/ocr/predict/prediction.py | 15 ++++++++-----
 dan/utils.py                  |  4 ++--
 docs/usage/predict.md         | 41 ++++++++++++++++++-----------------
 tests/test_prediction.py      |  7 ++++--
 5 files changed, 48 insertions(+), 33 deletions(-)

diff --git a/dan/ocr/predict/__init__.py b/dan/ocr/predict/__init__.py
index 1c65254a..0ee92b8b 100644
--- a/dan/ocr/predict/__init__.py
+++ b/dan/ocr/predict/__init__.py
@@ -6,6 +6,7 @@ Predict on an image using a trained DAN model.
 import pathlib
 
 from dan.ocr.predict.prediction import run
+from dan.utils import parse_tokens
 
 
 def add_predict_parser(subcommands) -> None:
@@ -51,13 +52,13 @@ def add_predict_parser(subcommands) -> None:
         help="Path to the output folder.",
         required=True,
     )
+    # Optional arguments.
     parser.add_argument(
         "--tokens",
-        type=pathlib.Path,
-        required=True,
+        type=parse_tokens,
+        required=False,
         help="Path to a yaml file containing a mapping between starting tokens and end tokens. Needed for entities.",
     )
-    # Optional arguments.
     parser.add_argument(
         "--image-extension",
         type=str,
@@ -154,5 +155,10 @@ def add_predict_parser(subcommands) -> None:
         default=1,
         required=False,
     )
-
+    parser.add_argument(
+        "--start-token",
+        help="Use a specific starting token at the beginning of the prediction. Useful when making predictions on different single pages.",
+        type=str,
+        required=False,
+    )
     parser.set_defaults(func=run)
diff --git a/dan/ocr/predict/prediction.py b/dan/ocr/predict/prediction.py
index 32758f8a..cdca2a90 100644
--- a/dan/ocr/predict/prediction.py
+++ b/dan/ocr/predict/prediction.py
@@ -5,7 +5,7 @@ import pickle
 import re
 from itertools import pairwise
 from pathlib import Path
-from typing import Iterable, List, Optional, Tuple
+from typing import Dict, Iterable, List, Optional, Tuple
 
 import numpy as np
 import torch
@@ -22,10 +22,10 @@ from dan.ocr.predict.attention import (
 )
 from dan.ocr.transforms import get_preprocessing_transforms
 from dan.utils import (
+    EntityType,
     ind_to_token,
     list_to_batches,
     pad_images,
-    parse_tokens,
     read_image,
 )
 
@@ -287,7 +287,8 @@ def process_batch(
     predict_objects: bool,
     threshold_method: str,
     threshold_value: int,
-    tokens: Path,
+    tokens: Dict[str, EntityType],
+    start_token: str,
 ) -> None:
     input_images, visu_images, input_sizes = [], [], []
     logger.info("Loading images...")
@@ -320,6 +321,7 @@ def process_batch(
         line_separators=line_separators,
         threshold_method=threshold_method,
         threshold_value=threshold_value,
+        start_token=start_token,
     )
 
     logger.info("Prediction parsing...")
@@ -336,7 +338,6 @@ def process_batch(
             result["confidences"] = {}
             char_confidences = prediction["confidences"][0]
             text = result["text"]
-            tokens = parse_tokens(tokens)
             start_tokens, end_tokens = zip(*list(tokens.values()))
             end_tokens = list(filter(bool, end_tokens))
 
@@ -422,7 +423,8 @@ def run(
     image_extension: str,
     gpu_device: int,
     batch_size: int,
-    tokens: Path,
+    tokens: Dict[str, EntityType],
+    start_token: str,
 ) -> None:
     """
     Predict a single image save the output
@@ -443,6 +445,8 @@ def run(
     :param threshold_value: Thresholding value to use for the "simple" thresholding method.
     :param gpu_device: Use a specific GPU if available.
     :param batch_size: Size of the batches for prediction.
+    :param tokens: NER tokens used.
+    :param start_token: Use a specific starting token at the beginning of the prediction. Useful when making predictions on different single pages.
     """
     # Create output directory if necessary
     if not output.exists():
@@ -472,4 +476,5 @@ def run(
             threshold_method,
             threshold_value,
             tokens,
+            start_token,
         )
diff --git a/dan/utils.py b/dan/utils.py
index 92954f8a..5cb62ff5 100644
--- a/dan/utils.py
+++ b/dan/utils.py
@@ -1,7 +1,7 @@
 # -*- coding: utf-8 -*-
 from itertools import islice
 from pathlib import Path
-from typing import NamedTuple
+from typing import Dict, NamedTuple
 
 import torch
 import torchvision.io as torchvision
@@ -106,7 +106,7 @@ def list_to_batches(iterable, n):
         yield batch
 
 
-def parse_tokens(filename: Path) -> dict:
+def parse_tokens(filename: Path) -> Dict[str, EntityType]:
     return {
         name: EntityType(**tokens)
         for name, tokens in yaml.safe_load(filename.read_text()).items()
diff --git a/docs/usage/predict.md b/docs/usage/predict.md
index 294e67aa..d6d46350 100644
--- a/docs/usage/predict.md
+++ b/docs/usage/predict.md
@@ -4,26 +4,27 @@ Use the `teklia-dan predict` command to apply a trained DAN model on an image.
 
 ## Description of parameters
 
-| Parameter                   | Description                                                                                     | Type    | Default       |
-| --------------------------- | ----------------------------------------------------------------------------------------------- | ------- | ------------- |
-| `--image`                   | Path to the image to predict. Must not be provided with `--image-dir`.                          | `Path`  |               |
-| `--image-dir`               | Path to the folder where the images to predict are stored. Must not be provided with `--image`. | `Path`  |               |
-| `--image-extension`         | The extension of the images in the folder. Ignored if `--image-dir` is not provided.            | `str`   | .jpg          |
-| `--model`                   | Path to the model to use for prediction                                                         | `Path`  |               |
-| `--parameters`              | Path to the YAML parameters file.                                                               | `Path`  |               |
-| `--charset`                 | Path to the charset file.                                                                       | `Path`  |               |
-| `--output`                  | Path to the output folder. Results will be saved in this directory.                             | `Path`  |               |
-| `--confidence-score`        | Whether to return confidence scores.                                                            | `bool`  | `False`       |
-| `--confidence-score-levels` | Level to return confidence scores. Should be any combination of `["line", "word", "char"]`.     | `str`   |               |
-| `--attention-map`           | Whether to plot attention maps.                                                                 | `bool`  | `False`       |
-| `--attention-map-scale`     | Image scaling factor before creating the GIF.                                                   | `float` | `0.5`         |
-| `--attention-map-level`     | Level to plot the attention maps. Should be in `["line", "word", "char"]`.                      | `str`   | `"line"`      |
-| `--predict-objects`         | Whether to return polygons coordinates.                                                         | `bool`  | `False`       |
-| `--word-separators`         | List of word separators.                                                                        | `list`  | `[" ", "\n"]` |
-| `--line-separators`         | List of line separators.                                                                        | `list`  | `["\n"]`      |
-| `--threshold-method`        | Method to use for attention mask thresholding. Should be in `["otsu", "simple"]`.               | `str`   | `"otsu"`      |
-| `--threshold-value `        | Threshold to use for the "simple" thresholding method.                                          | `int`   | `0`           |
-| `--batch-size `             | Size of the batches for prediction.                                                             | `int`   | `1`           |
+| Parameter                   | Description                                                                                                                 | Type    | Default       |
+| --------------------------- | --------------------------------------------------------------------------------------------------------------------------- | ------- | ------------- |
+| `--image`                   | Path to the image to predict. Must not be provided with `--image-dir`.                                                      | `Path`  |               |
+| `--image-dir`               | Path to the folder where the images to predict are stored. Must not be provided with `--image`.                             | `Path`  |               |
+| `--image-extension`         | The extension of the images in the folder. Ignored if `--image-dir` is not provided.                                        | `str`   | .jpg          |
+| `--model`                   | Path to the model to use for prediction                                                                                     | `Path`  |               |
+| `--parameters`              | Path to the YAML parameters file.                                                                                           | `Path`  |               |
+| `--charset`                 | Path to the charset file.                                                                                                   | `Path`  |               |
+| `--output`                  | Path to the output folder. Results will be saved in this directory.                                                         | `Path`  |               |
+| `--confidence-score`        | Whether to return confidence scores.                                                                                        | `bool`  | `False`       |
+| `--confidence-score-levels` | Level to return confidence scores. Should be any combination of `["line", "word", "char"]`.                                 | `str`   |               |
+| `--attention-map`           | Whether to plot attention maps.                                                                                             | `bool`  | `False`       |
+| `--attention-map-scale`     | Image scaling factor before creating the GIF.                                                                               | `float` | `0.5`         |
+| `--attention-map-level`     | Level to plot the attention maps. Should be in `["line", "word", "char"]`.                                                  | `str`   | `"line"`      |
+| `--predict-objects`         | Whether to return polygons coordinates.                                                                                     | `bool`  | `False`       |
+| `--word-separators`         | List of word separators.                                                                                                    | `list`  | `[" ", "\n"]` |
+| `--line-separators`         | List of line separators.                                                                                                    | `list`  | `["\n"]`      |
+| `--threshold-method`        | Method to use for attention mask thresholding. Should be in `["otsu", "simple"]`.                                           | `str`   | `"otsu"`      |
+| `--threshold-value `        | Threshold to use for the "simple" thresholding method.                                                                      | `int`   | `0`           |
+| `--batch-size `             | Size of the batches for prediction.                                                                                         | `int`   | `1`           |
+| `--start-token `            | Use a specific starting token at the beginning of the prediction. Useful when making predictions on different single pages. | `str`   | `None`        |
 
 ## Examples
 
diff --git a/tests/test_prediction.py b/tests/test_prediction.py
index b9e6e877..3acbb433 100644
--- a/tests/test_prediction.py
+++ b/tests/test_prediction.py
@@ -7,6 +7,7 @@ import pytest
 
 from dan.ocr.predict.prediction import DAN
 from dan.ocr.predict.prediction import run as run_prediction
+from dan.utils import parse_tokens
 
 
 @pytest.mark.parametrize(
@@ -316,7 +317,8 @@ def test_run_prediction(
         image_extension=None,
         gpu_device=None,
         batch_size=1,
-        tokens=prediction_data_path / "tokens.yml",
+        tokens=parse_tokens(prediction_data_path / "tokens.yml"),
+        start_token=None,
     )
 
     prediction = json.loads((tmp_path / image_name).with_suffix(".json").read_text())
@@ -514,7 +516,8 @@ def test_run_prediction_batch(
         image_extension=".png",
         gpu_device=None,
         batch_size=batch_size,
-        tokens=prediction_data_path / "tokens.yml",
+        tokens=parse_tokens(prediction_data_path / "tokens.yml"),
+        start_token=None,
     )
 
     for image_name, expected_prediction in zip(image_names, expected_predictions):
-- 
GitLab