From ead5ff4c6b9a3c59905012147d27eae2cd864771 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sol=C3=A8ne=20Tarride?= <starride@teklia.com>
Date: Thu, 23 Feb 2023 20:42:46 +0100
Subject: [PATCH] use regex delimiters

---
 dan/predict/attention.py  | 12 +++------
 dan/predict/prediction.py | 54 ++++++++++++++++++++++-----------------
 2 files changed, 34 insertions(+), 32 deletions(-)

diff --git a/dan/predict/attention.py b/dan/predict/attention.py
index aab70dff..33a8c6a9 100644
--- a/dan/predict/attention.py
+++ b/dan/predict/attention.py
@@ -1,4 +1,6 @@
 # -*- coding: utf-8 -*-
+import re
+
 import cv2
 import numpy as np
 from PIL import Image
@@ -18,17 +20,11 @@ def split_text(text, level, word_separators, line_separators):
         offset = 0
     # split into words
     elif level == "word":
-        main_sep = word_separators[0]
-        for other_sep in word_separators[1:]:
-            text = text.replace(other_sep, main_sep)
-        text_split = text.split(main_sep)
+        text_split = re.split(word_separators, text)
         offset = 1
     # split into lines
     elif level == "line":
-        main_sep = line_separators[0]
-        for other_sep in line_separators[1:]:
-            text = text.replace(other_sep, main_sep)
-        text_split = text.split(main_sep)
+        text_split = re.split(line_separators, text)
         offset = 1
     else:
         logger.error("Level should be either 'char', 'word', or 'line'")
diff --git a/dan/predict/prediction.py b/dan/predict/prediction.py
index 69ce4029..7984acfa 100644
--- a/dan/predict/prediction.py
+++ b/dan/predict/prediction.py
@@ -2,6 +2,7 @@
 
 import os
 import pickle
+import re
 
 import cv2
 import numpy as np
@@ -17,30 +18,6 @@ from dan.predict.attention import plot_attention
 from dan.utils import read_image, round_floats
 
 
-def compute_prob_by_separator(characters, probabilities, separators=["\n"]):
-    """
-    Split text and confidences using separators and return a list of average confidence scores.
-    :param characters: list of characters.
-    :param probabilities: list of probabilities.
-    :param separators: list of characters to split text. Use ["\n", " "] for word confidences and ["\n"] for line confidences.
-    Returns a list confidence scores.
-    """
-    probs = []
-    prob_split = []
-    text_split = ""
-    for char, prob in zip(characters, probabilities):
-        if char not in separators:
-            prob_split.append(prob)
-            text_split += char
-        elif text_split:
-            probs.append(np.mean(prob_split))
-            prob_split = []
-            text_split = ""
-    if text_split:
-        probs.append(np.mean(prob_split))
-    return probs
-
-
 class DAN:
     """
     The DAN class is used to apply a DAN model.
@@ -226,6 +203,31 @@ class DAN:
         return out
 
 
+def parse_delimiters(delimiters):
+    return re.compile(r"|".join(delimiters))
+
+
+def compute_prob_by_separator(characters, probabilities, separator):
+    """
+    Split text and confidences using separators and return a list of average confidence scores.
+    :param characters: list of characters.
+    :param probabilities: list of probabilities.
+    :param separators: regex for separators. Use parse_delimiters(["\n", " "]) for word confidences and parse_delimiters(["\n"]) for line confidences.
+    Returns a list confidence scores.
+    """
+    # match anything except separators, get start and end index
+    pattern = re.compile(f"[^{separator.pattern}]+")
+    matches = [(m.start(), m.end()) for m in re.finditer(pattern, characters)]
+
+    # Iterate over text pieces and compute mean confidence
+    probs = []
+    for match in matches:
+        start = match[0]
+        end = match[1]
+        probs.append(np.mean(probabilities[start:end]))
+    return probs
+
+
 def run(
     image,
     model,
@@ -271,6 +273,10 @@ def run(
     text = prediction["text"][0]
     result = {"text": text}
 
+    # Parse delimiters to regex
+    word_separators = parse_delimiters(word_separators)
+    line_separators = parse_delimiters(line_separators)
+
     # Average character-based confidence scores
     if confidence_score:
         char_confidences = prediction["confidences"][0]
-- 
GitLab