From 986fdaf0f73cff0d807801a4adea1c7ead21f4e6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sol=C3=A8ne=20Tarride?= <starride@teklia.com>
Date: Fri, 24 Feb 2023 09:47:13 +0000
Subject: [PATCH] Compute confidence scores by char, word or line

---
 dan/predict/__init__.py   | 25 ++++++++++++++-
 dan/predict/attention.py  | 22 ++++++++++----
 dan/predict/prediction.py | 64 ++++++++++++++++++++++++++++++++++++---
 dan/utils.py              |  7 +++++
 docs/usage/predict.md     | 39 +++++++++++-------------
 5 files changed, 123 insertions(+), 34 deletions(-)

diff --git a/dan/predict/__init__.py b/dan/predict/__init__.py
index 92aa0f15..1528d35d 100644
--- a/dan/predict/__init__.py
+++ b/dan/predict/__init__.py
@@ -61,6 +61,14 @@ def add_predict_parser(subcommands) -> None:
         help="Whether to return confidence scores.",
         required=False,
     )
+    parser.add_argument(
+        "--confidence-score-levels",
+        default="",
+        type=str,
+        nargs="+",
+        help="Levels of confidence scores. Should be a list of any combinaison of ['char', 'word', 'line'].",
+        required=False,
+    )
     parser.add_argument(
         "--attention-map",
         action="store_true",
@@ -82,5 +90,20 @@ def add_predict_parser(subcommands) -> None:
         help="Image scaling factor before creating the GIF",
         required=False,
     )
-
+    parser.add_argument(
+        "--word-separators",
+        default=[" ", "\n"],
+        type=str,
+        nargs="+",
+        help="String separators used to split text into words.",
+        required=False,
+    )
+    parser.add_argument(
+        "--line-separators",
+        default=["\n"],
+        type=str,
+        nargs="+",
+        help="String separators used to split text into lines.",
+        required=False,
+    )
     parser.set_defaults(func=run)
diff --git a/dan/predict/attention.py b/dan/predict/attention.py
index bdfe57a1..33a8c6a9 100644
--- a/dan/predict/attention.py
+++ b/dan/predict/attention.py
@@ -1,4 +1,6 @@
 # -*- coding: utf-8 -*-
+import re
+
 import cv2
 import numpy as np
 from PIL import Image
@@ -6,7 +8,7 @@ from PIL import Image
 from dan import logger
 
 
-def split_text(text, level):
+def split_text(text, level, word_separators, line_separators):
     """
     Split text into a list of characters, words, or lines.
     :param text: Text prediction from DAN
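+    :param word_separators: Regex pattern used to split the text into words
+    :param line_separators: Regex pattern used to split the text into lines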
@@ -18,19 +20,27 @@ def split_text(text, level):
         offset = 0
     # split into words
     elif level == "word":
-        text = text.replace("\n", " ")
-        text_split = text.split(" ")
+        text_split = re.split(word_separators, text)
         offset = 1
     # split into lines
     elif level == "line":
-        text_split = text.split("\n")
+        text_split = re.split(line_separators, text)
         offset = 1
     else:
         logger.error("Level should be either 'char', 'word', or 'line'")
     return text_split, offset
 
 
-def plot_attention(image, text, weights, level, scale, outname):
+def plot_attention(
+    image,
+    text,
+    weights,
+    level,
+    scale,
+    outname,
+    word_separators=["\n", " "],
+    line_separators=["\n"],
+):
     """
     Create a gif by blending attention maps to the image for each text piece (char, word or line)
     :param image: Input image in PIL format
@@ -48,7 +58,7 @@ def plot_attention(image, text, weights, level, scale, outname):
     image = Image.fromarray(image)
 
     # Split text into characters, words or lines
-    text_list, offset = split_text(text, level)
+    text_list, offset = split_text(text, level, word_separators, line_separators)
 
     # Iterate on characters, words or lines
     tot_len = 0
diff --git a/dan/predict/prediction.py b/dan/predict/prediction.py
index 0bb427ea..3cff98d7 100644
--- a/dan/predict/prediction.py
+++ b/dan/predict/prediction.py
@@ -2,6 +2,7 @@
 
 import os
 import pickle
+import re
 
 import cv2
 import numpy as np
@@ -14,7 +15,7 @@ from dan.decoder import GlobalHTADecoder
 from dan.models import FCN_Encoder
 from dan.ocr.utils import LM_ind_to_str
 from dan.predict.attention import plot_attention
-from dan.utils import read_image
+from dan.utils import read_image, round_floats
 
 
 class DAN:
@@ -85,7 +86,13 @@ class DAN:
         input_image = (input_image - self.mean) / self.std
         return input_image
 
-    def predict(self, input_tensor, input_sizes, confidences=False, attentions=False):
+    def predict(
+        self,
+        input_tensor,
+        input_sizes,
+        confidences=False,
+        attentions=False,
+    ):
         """
         Run prediction on an input image.
         :param input_tensor: A batch of images to predict.
@@ -165,11 +172,13 @@ class DAN:
                 if torch.all(reached_end):
                     break
 
+            # Concatenate tensors for each token
             confidence_scores = (
                 torch.cat(confidence_scores, dim=1).cpu().detach().numpy()
             )
             attention_maps = torch.cat(attention_maps, dim=1).cpu().detach().numpy()
 
+            # Remove beginning-of-transcription (bot) and end-of-transcription (eot) tokens
             predicted_tokens = predicted_tokens[:, 1:]
             prediction_len[torch.eq(reached_end, False)] = self.max_chars - 1
             predicted_tokens = [
@@ -178,9 +187,12 @@ class DAN:
             confidence_scores = [
                 confidence_scores[i, : prediction_len[i]].tolist() for i in range(b)
             ]
+
+            # Transform tokens to characters
             predicted_text = [
                 LM_ind_to_str(self.charset, t, oov_symbol="") for t in predicted_tokens
             ]
+
             logger.info("Images processed")
 
         out = {"text": predicted_text}
@@ -191,6 +203,26 @@ class DAN:
         return out
 
 
+def parse_delimiters(delimiters):
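+    """
+    Compile the list of delimiters into a single regex pattern (delimiters joined with "|").
+    """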
+    return re.compile(r"|".join(delimiters))
+
+
+def compute_prob_by_separator(characters, probabilities, separator):
+    """
+    Split text and confidences using separators and return a list of average confidence scores.
+    :param characters: The predicted text (string of characters).
+    :param probabilities: List of character-level confidence scores.
+    :param separator: Compiled regex matching the separators. Use parse_delimiters(["\n", " "]) for word confidences and parse_delimiters(["\n"]) for line confidences.
+    Returns a list of confidence scores.
+    """
+    # match anything except separators, get start and end index
+    pattern = re.compile(f"[^{separator.pattern}]+")
+    matches = [(m.start(), m.end()) for m in re.finditer(pattern, characters)]
+
+    # Iterate over text pieces and compute mean confidence
+    return [np.mean(probabilities[start:end]) for (start, end) in matches]
+
+
 def run(
     image,
     model,
@@ -199,9 +231,12 @@ def run(
     output,
     scale,
     confidence_score,
+    confidence_score_levels,
     attention_map,
     attention_map_level,
     attention_map_scale,
+    word_separators,
+    line_separators,
 ):
     # Create output directory if necessary
     if not os.path.exists(output):
@@ -230,12 +265,29 @@ def run(
         confidences=confidence_score,
         attentions=attention_map,
     )
-    result = {"text": prediction["text"][0]}
+    text = prediction["text"][0]
+    result = {"text": text}
+
+    # Parse delimiters to regex
+    word_separators = parse_delimiters(word_separators)
+    line_separators = parse_delimiters(line_separators)
 
     # Average character-based confidence scores
     if confidence_score:
-        # TODO: select the level for confidence scores (char, word, line, total)
-        result["confidence"] = np.around(np.mean(prediction["confidences"][0]), 2)
+        char_confidences = prediction["confidences"][0]
+        result["confidences"] = {"total": np.around(np.mean(char_confidences), 2)}
+        if "word" in confidence_score_levels:
+            word_probs = compute_prob_by_separator(
+                text, char_confidences, word_separators
+            )
+            result["confidences"].update({"word": round_floats(word_probs)})
+        if "line" in confidence_score_levels:
+            line_probs = compute_prob_by_separator(
+                text, char_confidences, line_separators
+            )
+            result["confidences"].update({"line": round_floats(line_probs)})
+        if "char" in confidence_score_levels:
+            result["confidences"].update({"char": round_floats(char_confidences)})
 
     # Save gif with attention map
     if attention_map:
@@ -247,6 +299,8 @@ def run(
             weights=prediction["attentions"][0],
             level=attention_map_level,
             scale=attention_map_scale,
+            word_separators=word_separators,
+            line_separators=line_separators,
             outname=gif_filename,
         )
         result["attention_gif"] = gif_filename
diff --git a/dan/utils.py b/dan/utils.py
index 2fd1529a..f2f27dd2 100644
--- a/dan/utils.py
+++ b/dan/utils.py
@@ -199,3 +199,10 @@ def read_image(filename, scale=1.0):
         height = int(image.shape[0] * scale)
         image = cv2.resize(image, (width, height), interpolation=cv2.INTER_AREA)
     return image
+
+
+def round_floats(float_list, decimals=2):
+    """
+    Round a list of floats to a fixed number of decimals.
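+    e.g. round_floats([0.1234, 0.5678]) returns [0.12, 0.57].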
+    """
+    return [np.around(num, decimals) for num in float_list]
diff --git a/docs/usage/predict.md b/docs/usage/predict.md
index 4c5d9050..4cf262d5 100644
--- a/docs/usage/predict.md
+++ b/docs/usage/predict.md
@@ -1,22 +1,24 @@
 # Predict
 
-## Description
-
 Use the `teklia-dan predict` command to predict a trained DAN model on an image.
 
-| Parameter                      | Description                                                                  | Type     | Default |
-| ------------------------------ | ---------------------------------------------------------------------------- | -------- | ------- |
-| `--image`                      | Path to the image to predict.                                                | `Path`   |         |
-| `--model`                      | Path to the model to use for prediction                                      | `Path`   |         |
-| `--parameters`                 | Path to the YAML parameters file.                                            | `Path`   |         |
-| `--charset`                    | Path to the charset file.                                                    | `Path`   |         |
-| `--output`                     | Path to the output folder. Results will be saved in this directory.          | `Path`   |         |
-| `--scale`                      | Image scaling factor before feeding it to DAN.                               | `float`  | 1.0     |
-| `--confidence-score`           | Whether to return confidence scores.                                         | `bool`   | False   |
-| `--attention-map`              | Whether to plot attention maps.                                              | `bool`   | False   |
-| `--attention-map-level`        | Level to plot the attention maps. Should be in  `["line", "word", "char"]`.  | `str`    | line    |
-| `--attention-map-scale`        | Image scaling factor before creating the GIF.                                | `float`  | 0.5     |
-
+## Description of parameters
+
+| Parameter                   | Description                                                                                  | Type    | Default       |
+| --------------------------- | -------------------------------------------------------------------------------------------- | ------- | ------------- |
+| `--image`                   | Path to the image to predict.                                                                | `Path`  |               |
+| `--model`                   | Path to the model to use for prediction                                                      | `Path`  |               |
+| `--parameters`              | Path to the YAML parameters file.                                                            | `Path`  |               |
+| `--charset`                 | Path to the charset file.                                                                    | `Path`  |               |
+| `--output`                  | Path to the output folder. Results will be saved in this directory.                          | `Path`  |               |
+| `--scale`                   | Image scaling factor before feeding it to DAN.                                               | `float` | `1.0`         |
+| `--confidence-score`        | Whether to return confidence scores.                                                         | `bool`  | `False`       |
+| `--confidence-score-levels` | Levels of confidence scores to return. Should be any combination of `["line", "word", "char"]`. | `str`   |               |
+| `--attention-map`           | Whether to plot attention maps.                                                              | `bool`  | `False`       |
+| `--attention-map-scale`     | Image scaling factor before creating the GIF.                                                | `float` | `0.5`         |
+| `--attention-map-level`     | Level to plot the attention maps. Should be in `["line", "word", "char"]`.                   | `str`   | `"line"`      |
+| `--word-separators`         | List of word separators.                                                                     | `list`  | `[" ", "\n"]` |
+| `--line-separators`         | List of line separators.                                                                     | `list`  | `["\n"]`      |
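+
+For instance, word- and line-level confidence scores can be requested as follows (the paths are placeholders):
+
+```shell
+teklia-dan predict \
+    --image example.jpg \
+    --model model.pt \
+    --parameters parameters.yml \
+    --charset charset.pkl \
+    --output predict/ \
+    --confidence-score \
+    --confidence-score-levels word line
+```
+
+The output JSON then contains a `confidences` key with the `total` score and the requested `word` and `line` scores.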
 
 ## Examples
 
@@ -101,10 +103,3 @@ It will create the following JSON file named `dan_humu_page/predict/example.json
 <video autoplay>
     <source src="../assets/example_word.gif">
 </video>
-
-## Remarks
-
-The script plotting attention maps assumes that:
-
-* words are separated with the symbol ` `
-* lines are separated with the symbol `\n`
-- 
GitLab