From 96327069ef6cdfdf90644a811591a2b79b8487dd Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sol=C3=A8ne=20Tarride?= <starride@teklia.com>
Date: Thu, 23 Feb 2023 13:51:05 +0100
Subject: [PATCH] Add --word-separators and --line-separators to the predict command

---
 dan/predict/__init__.py   | 19 +++++++++++++++++--
 dan/predict/attention.py  | 26 ++++++++++++++++++++------
 dan/predict/prediction.py | 12 ++++++++++--
 3 files changed, 47 insertions(+), 10 deletions(-)

diff --git a/dan/predict/__init__.py b/dan/predict/__init__.py
index 4eb9f909..1528d35d 100644
--- a/dan/predict/__init__.py
+++ b/dan/predict/__init__.py
@@ -63,7 +63,7 @@ def add_predict_parser(subcommands) -> None:
     )
     parser.add_argument(
         "--confidence-score-levels",
-        default=[],
+        default="",
         type=str,
         nargs="+",
         help="Levels of confidence scores. Should be a list of any combinaison of ['char', 'word', 'line'].",
@@ -90,5 +90,20 @@ def add_predict_parser(subcommands) -> None:
         help="Image scaling factor before creating the GIF",
         required=False,
     )
-
+    parser.add_argument(
+        "--word-separators",
+        default=[" ", "\n"],
+        type=str,
+        nargs="+",
+        help="String separators used to split text into words.",
+        required=False,
+    )
+    parser.add_argument(
+        "--line-separators",
+        default=["\n"],
+        type=str,
+        nargs="+",
+        help="String separators used to split text into lines.",
+        required=False,
+    )
     parser.set_defaults(func=run)
diff --git a/dan/predict/attention.py b/dan/predict/attention.py
index bdfe57a1..aab70dff 100644
--- a/dan/predict/attention.py
+++ b/dan/predict/attention.py
@@ -6,7 +6,7 @@ from PIL import Image
 from dan import logger
 
 
-def split_text(text, level):
+def split_text(text, level, word_separators, line_separators):
     """
     Split text into a list of characters, word, or lines.
     :param text: Text prediction from DAN
@@ -18,19 +18,33 @@ def split_text(text, level):
         offset = 0
     # split into words
     elif level == "word":
-        text = text.replace("\n", " ")
-        text_split = text.split(" ")
+        main_sep = word_separators[0]
+        for other_sep in word_separators[1:]:
+            text = text.replace(other_sep, main_sep)
+        text_split = text.split(main_sep)
         offset = 1
     # split into lines
     elif level == "line":
-        text_split = text.split("\n")
+        main_sep = line_separators[0]
+        for other_sep in line_separators[1:]:
+            text = text.replace(other_sep, main_sep)
+        text_split = text.split(main_sep)
         offset = 1
     else:
         logger.error("Level should be either 'char', 'word', or 'line'")
     return text_split, offset
 
 
-def plot_attention(image, text, weights, level, scale, outname):
+def plot_attention(
+    image,
+    text,
+    weights,
+    level,
+    scale,
+    outname,
+    word_separators=["\n", " "],
+    line_separators=["\n"],
+):
     """
     Create a gif by blending attention maps to the image for each text piece (char, word or line)
     :param image: Input image in PIL format
@@ -48,7 +62,7 @@ def plot_attention(image, text, weights, level, scale, outname):
     image = Image.fromarray(image)
 
     # Split text into characters, words or lines
-    text_list, offset = split_text(text, level)
+    text_list, offset = split_text(text, level, word_separators, line_separators)
 
     # Iterate on characters, words or lines
     tot_len = 0
diff --git a/dan/predict/prediction.py b/dan/predict/prediction.py
index a43ed991..69ce4029 100644
--- a/dan/predict/prediction.py
+++ b/dan/predict/prediction.py
@@ -238,6 +238,8 @@ def run(
     attention_map,
     attention_map_level,
     attention_map_scale,
+    word_separators,
+    line_separators,
 ):
     # Create output directory if necessary
     if not os.path.exists(output):
@@ -274,10 +276,14 @@ def run(
         char_confidences = prediction["confidences"][0]
         result["confidences"] = {"total": np.around(np.mean(char_confidences), 2)}
         if "word" in confidence_score_levels:
-            word_probs = compute_prob_by_separator(text, char_confidences, ["\n", " "])
+            word_probs = compute_prob_by_separator(
+                text, char_confidences, word_separators
+            )
             result["confidences"].update({"word": round_floats(word_probs)})
         if "line" in confidence_score_levels:
-            line_probs = compute_prob_by_separator(text, char_confidences, ["\n"])
+            line_probs = compute_prob_by_separator(
+                text, char_confidences, line_separators
+            )
             result["confidences"].update({"line": round_floats(line_probs)})
         if "char" in confidence_score_levels:
             result["confidences"].update({"char": round_floats(char_confidences)})
@@ -292,6 +298,8 @@ def run(
             weights=prediction["attentions"][0],
             level=attention_map_level,
             scale=attention_map_scale,
+            word_separators=word_separators,
+            line_separators=line_separators,
             outname=gif_filename,
         )
         result["attention_gif"] = gif_filename
-- 
GitLab