From 04b15f97469b99f5c4760d14ea277b3ac390c100 Mon Sep 17 00:00:00 2001
From: starride-teklia <starride@teklia.com>
Date: Fri, 7 Apr 2023 08:53:14 +0000
Subject: [PATCH] Predict polygons but preserve old behavior

---
 dan/predict/attention.py  | 238 ++++++++++++++++++++++++++------------
 dan/predict/prediction.py | 137 ++++++++++++----------
 2 files changed, 240 insertions(+), 135 deletions(-)

diff --git a/dan/predict/attention.py b/dan/predict/attention.py
index 58564416..38d0032e 100644
--- a/dan/predict/attention.py
+++ b/dan/predict/attention.py
@@ -6,38 +6,17 @@ import numpy as np
 from PIL import Image
 
 from dan import logger
+from dan.utils import round_floats
 
 
-def split_text(text: str, level: str, word_separators, line_separators):
-    """
-    Split text into a list of characters, word, or lines.
-    :param text: Text prediction from DAN
-    :param level: Level to visualize from [char, word, line]
-    """
-    # split into characters
-    if level == "char":
-        text_split = list(text)
-        offset = 0
-    # split into words
-    elif level == "word":
-        text_split = re.split(word_separators, text)
-        offset = 1
-    # split into lines
-    elif level == "line":
-        text_split = re.split(line_separators, text)
-        offset = 1
-    else:
-        logger.error("Level should be either 'char', 'word', or 'line'")
-    return text_split, offset
-
-
-def compute_coverage(text: str, max_value: float, offset: int, attentions):
+def compute_coverage(text: str, max_value: float, offset: int, attentions, size: tuple):
     """
     Aggregates attention maps for the current text piece (char, word, line)
     :param text: Text piece selected with offset after splitting DAN prediction
     :param max_value: Maximum "attention intensity" for parts of a text piece, used for normalization
     :param offset: Offset value to get the relevant part of text piece
     :param attentions: Attention weights of size (n_char, feature_height, feature_width)
+    :param size: Target size (width, height) to resize the coverage vector
     """
     _, height, width = attentions.shape
 
@@ -49,6 +28,11 @@ def compute_coverage(text: str, max_value: float, offset: int, attentions):
 
     # Normalize coverage vector
     coverage_vector = (coverage_vector / max_value * 255).astype(np.uint8)
+
+    # Resize it
+    if size:
+        coverage_vector = cv2.resize(coverage_vector, size)
+
     return coverage_vector
 
 
@@ -74,9 +58,82 @@ def blend_coverage(coverage_vector, image, mask, scale):
     return blend
 
 
-def get_predicted_polygons(
+def parse_delimiters(delimiters):
+    return re.compile(r"|".join(delimiters))
+
+
+def compute_prob_by_separator(characters, probabilities, separator):
+    """
+    Split text and confidences using separators and return a list of text pieces with their average confidence scores.
+    :param characters: list of characters.
+    :param probabilities: list of character probabilities.
+    :param separator: regex for separators. Use parse_delimiters(["\n", " "]) for word confidences and parse_delimiters(["\n"]) for line confidences.
+    Returns a list of text pieces and a list of confidence scores.
+    """
+    # match anything except separators, get start and end index
+    pattern = re.compile(f"[^{separator.pattern}]+")
+    matches = [(m.start(), m.end()) for m in re.finditer(pattern, characters)]
+
+    # Iterate over text pieces and compute mean confidence
+    probs = [np.mean(probabilities[start:end]) for (start, end) in matches]
+    texts = [characters[start:end] for (start, end) in matches]
+    return texts, probs
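# Usage sketch (illustrative, not part of the patch): word-level confidences for a
# two-word prediction. parse_delimiters builds the separators regex and
# compute_prob_by_separator averages the character confidences of each word.
word_sep = parse_delimiters(["\n", " "])
texts, probs = compute_prob_by_separator(
    "the cat", [0.9, 0.8, 0.7, 1.0, 0.6, 0.5, 0.4], word_sep
)
# texts == ["the", "cat"]; probs == [0.8, 0.5] (mean of each word's character confidences)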
+
+
+def split_text(text: str, level: str, word_separators, line_separators):
+    """
+    Split text into a list of characters, words, or lines.
+    :param text: Text prediction from DAN
+    :param level: Level to visualize from [char, word, line]
+    :param word_separators: List of word separators
+    :param line_separators: List of line separators
+    """
+    # split into characters
+    if level == "char":
+        text_split = list(text)
+        offset = 0
+    # split into words
+    elif level == "word":
+        text_split = re.split(word_separators, text)
+        offset = 1
+    # split into lines
+    elif level == "line":
+        text_split = re.split(line_separators, text)
+        offset = 1
+    else:
+        logger.error("Level should be either 'char', 'word', or 'line'")
+    return text_split, offset
+
+
+def split_text_and_confidences(
+    text, confidences, level, word_separators, line_separators
+):
+    """
+    Split text into a list of characters, words or lines with corresponding confidence scores
+    :param text: Text prediction from DAN
+    :param confidences: Character confidences
+    :param level: Level to visualize from [char, word, line]
+    :param word_separators: List of word separators
+    :param line_separators: List of line separators
+    """
+    if level == "char":
+        texts = list(text)
+        probs = confidences
+        offset = 0
+    elif level == "word":
+        texts, probs = compute_prob_by_separator(text, confidences, word_separators)
+        offset = 1
+    elif level == "line":
+        texts, probs = compute_prob_by_separator(text, confidences, line_separators)
+        offset = 1
+    else:
+        logger.error("Level should be either 'char', 'word', or 'line'")
+    return texts, round_floats(probs), offset
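# Usage sketch (illustrative values): the same helpers used to split a prediction at
# line level; offset is 1 because one separator character sits between pieces.
line_confs = [0.9, 0.8, 0.7, 1.0, 0.6, 0.5, 0.4, 1.0, 0.8, 0.8, 0.8]
texts, confs, offset = split_text_and_confidences(
    "the cat\nsat", line_confs, "line", parse_delimiters(["\n", " "]), parse_delimiters(["\n"])
)
# texts == ["the cat", "sat"]; confs == [0.7, 0.8]; offset == 1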
+
+
+def get_predicted_polygons_with_confidence(
     text,
     weights,
+    confidences,
     level,
     height,
     width,
@@ -87,63 +144,85 @@ def get_predicted_polygons(
     Returns the polygons of each object of the current prediction
     :param text: Text predicted by DAN
     :param weights: Attention weights of size (n_char, feature_height, feature_width)
+    :param confidences: Character confidences
     :param level: Level to display (must be in [char, word, line])
     :param height: Original image height
     :param width: Original image width
+    :param word_separators: List of word separators
+    :param line_separators: List of line separators
     """
     # Split text into characters, words or lines
-    text_list, offset = split_text(text, level, word_separators, line_separators)
-    max_value = weights.sum(0).max()
+    text_list, confidence_list, offset = split_text_and_confidences(
+        text, confidences, level, word_separators, line_separators
+    )
 
-    # Set offset based on current text_piece to be used.
-    return [
-        get_polygon(
-            text_piece, level, offset * n_offset, max_value, weights, height, width
+    max_value = weights.sum(0).max()
+    polygons = []
+    start_index = 0
+    for text_piece, confidence in zip(text_list, confidence_list):
+        polygon = get_polygon(
+            text_piece, max_value, start_index, weights, size=(width, height)
         )
-        for n_offset, text_piece in enumerate(text_list)
-    ]
+        polygon["text"] = text_piece
+        polygon["text_confidence"] = confidence
+        polygons.append(polygon)
+    return polygons
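# Output sketch (illustrative values): at word level, each entry of the returned list
# describes one word, e.g.
#   {
#       "confidence": 0.87,                                        # attention-based polygon confidence
#       "polygon": [[110, 45], [300, 45], [300, 80], [110, 80]],
#       "text": "cat",
#       "text_confidence": 0.95,                                   # mean character confidence
#   }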
+
+
+def compute_contour_metrics(coverage_vector, contour):
+    """
+    Compute the maximum attention value inside the contour and its score (max value * contour area).
+    :param coverage_vector: Aggregated attention weights of the current text piece, resized to the image size: (image_height, image_width)
+    :param contour: Contour of the current attention blob
+    """
+    # draw the contour zone
+    mask = np.zeros(coverage_vector.shape, dtype=np.uint8)
+    cv2.drawContours(mask, [contour], -1, (255), -1)
+
+    max_value = (
+        np.where(mask > 0, coverage_vector, 0).max() / 255
+    )  # cv2.max(coverage_vector, mask=mask)[0] / 255.
+    area = cv2.contourArea(contour)
+    return max_value, max_value * area
 
 
-def get_polygon(text_piece, level, offset, max_value, weights, height, width):
+def get_polygon(text, max_value, offset, weights, size=None, return_contours=False):
     """
     Gets polygon associated with element of current text_piece, indexed by offset
-    :param text_piece: Current text element
-    :param level: Level to display (must be in [char, word, line])
-    :param offset: Offset value to get the relevant part of text piece
+    :param text: Text piece selected with offset after splitting DAN prediction
     :param max_value: Maximum "attention intensity" for parts of a text piece, used for normalization
-    :param weights: Attention weights of size (n_char, feature_height, feature_width)
-    :param height: Original image height
-    :param width: Original image width
+    :param offset: Offset value to get the relevant part of text piece
+    :param weights: Attention weights of size (n_char, feature_height, feature_width)
+    :param size: Target size (width, height) to resize the coverage vector
+    :param return_contours: Return the contour of the current polygon (used for plotting)
     """
-    coverage_vector = compute_coverage(text_piece, max_value, offset, weights)
-    coverage_vector = cv2.resize(coverage_vector, (width, height))
+    # Compute coverage vector
+    coverage_vector = compute_coverage(text, max_value, offset, weights, size=size)
 
     # Generate a binary image for the current channel.
-    bin_img = coverage_vector.copy()
-    bin_img[bin_img > 0] = 1
+    bin_mask = np.asarray(np.where(coverage_vector > 5, 255, 0), dtype=np.uint8)
+
+    # Detect the objects contours
+    contours, _ = cv2.findContours(bin_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+
+    # Select best contour
+    metrics = [compute_contour_metrics(coverage_vector, cnt) for cnt in contours]
+    confidences, scores = map(list, zip(*metrics))
+    best_contour = contours[np.argmax(scores)]
+    confidence = round(confidences[np.argmax(scores)] / max_value, 2)
+
+    # Format for JSON
+    polygon = {
+        "confidence": confidence,
+        "polygon": [coordinates[0].tolist() for coordinates in best_contour],
+    }
 
-    # Detect the objects contours.
-    contours, _ = cv2.findContours(
-        np.uint8(bin_img), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
-    )
+    if return_contours:
+        return polygon, best_contour
 
-    mask = np.zeros(coverage_vector.shape)
-    cv2.drawContours(mask, contours, -1, 1, -1)
-    confidence = round((np.sum(mask * coverage_vector) / np.sum(mask)), 2)
-
-    # Put together all contours for now.
-    pre_contours_tojson = [[item.tolist() for item in contours]]
-    # Quick hack to have better json format:
-    contours_tojson = []
-    for contour in pre_contours_tojson[0]:
-        for coordinate in contour:
-            contours_tojson.append(coordinate[0])
-
-    return {
-        "confidence": confidence,  # average of coverage vector on contours
-        "polygon": contours_tojson,
-        "type": level,
-    }
+    return polygon
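# Usage sketch (illustrative): polygon of the first word of a prediction, assuming
# `weights` has shape (n_char, feature_height, feature_width) and the original image
# is 2000 x 1500 pixels (width x height).
polygon = get_polygon("the", weights.sum(0).max(), 0, weights, size=(2000, 1500))
# polygon == {"confidence": <float>, "polygon": [[x1, y1], [x2, y2], ...]}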
 
 
 def plot_attention(
@@ -155,7 +234,7 @@ def plot_attention(
     outname,
     word_separators=["\n", " "],
     line_separators=["\n"],
-    output_polygons=False,
+    display_polygons=False,
 ):
     """
     Create a gif by blending attention maps to the image for each text piece (char, word or line)
@@ -165,11 +244,13 @@ def plot_attention(
     :param level: Level to display (must be in [char, word, line])
     :param scale: Scaling factor for the output gif image
     :param outname: Name of the gif image
+    :param word_separators: List of word separators
+    :param line_separators: List of line separators
+    :param display_polygons: Whether to draw the extracted polygons on the attention maps
     """
 
     height, width, _ = image.shape
     attention_map = []
-    polygons = []
 
     # Convert to PIL Image and create mask
     mask = Image.new("L", (width, height), color=(110))
@@ -180,21 +261,26 @@ def plot_attention(
 
     # Iterate on characters, words or lines
     tot_len = 0
-
     max_value = weights.sum(0).max()
 
     for text_piece in text_list:
         # Accumulate weights for the current word/line and resize to original image size
-        coverage_vector = compute_coverage(text_piece, max_value, tot_len, weights)
-        coverage_vector = cv2.resize(coverage_vector, (width, height))
+        coverage_vector = compute_coverage(
+            text_piece, max_value, tot_len, weights, (width, height)
+        )
 
         # Get polygons if flag is set:
-        if output_polygons:
-            polygons.append(
-                get_polygon(
-                    text_piece, level, tot_len, max_value, weights, height, width
-                )
+        if display_polygons:
+            # draw the contour
+            _, contour = get_polygon(
+                text_piece,
+                max_value,
+                tot_len,
+                weights,
+                (width, height),
+                return_contours=True,
             )
+            cv2.drawContours(coverage_vector, [contour], 0, (255), 3)
 
         # Keep track of text length
         tot_len += len(text_piece) + offset
@@ -210,5 +296,3 @@ def plot_attention(
         duration=1000,
         loop=True,
     )
-
-    return polygons
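# Usage sketch (illustrative): word-level attention gif with polygon outlines drawn on
# each frame, assuming the first three positional arguments are the original numpy
# image, the predicted text and the attention weights returned by the model.
plot_attention(
    image,
    text,
    weights,
    level="word",
    scale=0.5,
    outname="attention_map.gif",
    display_polygons=True,
)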
diff --git a/dan/predict/prediction.py b/dan/predict/prediction.py
index 7cd32694..09e7a7ba 100644
--- a/dan/predict/prediction.py
+++ b/dan/predict/prediction.py
@@ -2,7 +2,6 @@
 
 import os
 import pickle
-import re
 
 import cv2
 import numpy as np
@@ -14,8 +13,13 @@ from dan.datasets.extract.utils import save_json
 from dan.decoder import GlobalHTADecoder
 from dan.models import FCN_Encoder
 from dan.ocr.utils import LM_ind_to_str
-from dan.predict.attention import get_predicted_polygons, plot_attention
-from dan.utils import read_image, round_floats
+from dan.predict.attention import (
+    get_predicted_polygons_with_confidence,
+    parse_delimiters,
+    plot_attention,
+    split_text_and_confidences,
+)
+from dan.utils import read_image
 
 
 class DAN:
@@ -93,6 +97,7 @@ class DAN:
         confidences=False,
         attentions=False,
         attention_level=False,
+        extract_objects=False,
         word_separators=["\n", " "],
         line_separators=["\n"],
     ):
@@ -113,13 +118,20 @@ class DAN:
 
         # Run the prediction.
         with torch.no_grad():
-            b = input_tensor.size(0)
-            reached_end = torch.zeros((b,), dtype=torch.bool, device=self.device)
-            prediction_len = torch.zeros((b,), dtype=torch.int, device=self.device)
+            batch_size = input_tensor.size(0)
+            reached_end = torch.zeros(
+                (batch_size,), dtype=torch.bool, device=self.device
+            )
+            prediction_len = torch.zeros(
+                (batch_size,), dtype=torch.int, device=self.device
+            )
             predicted_tokens = (
-                torch.ones((b, 1), dtype=torch.long, device=self.device) * start_token
+                torch.ones((batch_size, 1), dtype=torch.long, device=self.device)
+                * start_token
+            )
+            predicted_tokens_len = torch.ones(
+                (batch_size,), dtype=torch.int, device=self.device
             )
-            predicted_tokens_len = torch.ones((b,), dtype=torch.int, device=self.device)
 
             whole_output = list()
             confidence_scores = list()
@@ -188,10 +200,11 @@ class DAN:
             predicted_tokens = predicted_tokens[:, 1:]
             prediction_len[torch.eq(reached_end, False)] = self.max_chars - 1
             predicted_tokens = [
-                predicted_tokens[i, : prediction_len[i]] for i in range(b)
+                predicted_tokens[i, : prediction_len[i]] for i in range(batch_size)
             ]
             confidence_scores = [
-                confidence_scores[i, : prediction_len[i]].tolist() for i in range(b)
+                confidence_scores[i, : prediction_len[i]].tolist()
+                for i in range(batch_size)
             ]
 
             # Transform tokens to characters
@@ -201,44 +214,30 @@ class DAN:
 
             logger.info("Images processed")
 
-        out = {"text": predicted_text}
+        out = {}
+
+        out["text"] = predicted_text
         if confidences:
             out["confidences"] = confidence_scores
         if attentions:
             out["attentions"] = attention_maps
-            # Also get information on polygons
-            out["objects"] = get_predicted_polygons(
-                predicted_text[0],
-                attention_maps[0],
-                attention_level,
-                input_sizes[0][0],
-                input_sizes[0][1],
-                word_separators,
-                line_separators,
-            )
+        if extract_objects:
+            out["objects"] = [
+                get_predicted_polygons_with_confidence(
+                    predicted_text[i],
+                    attention_maps[i],
+                    confidence_scores[i],
+                    attention_level,
+                    input_sizes[i][0],
+                    input_sizes[i][1],
+                    word_separators,
+                    line_separators,
+                )
+                for i in range(batch_size)
+            ]
         return out
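# Usage sketch (illustrative): requesting polygons for a batch, assuming `input_tensor`
# and `input_sizes` (one original size per image) come from the usual preprocessing.
prediction = dan_model.predict(
    input_tensor,
    input_sizes,
    confidences=True,
    attentions=True,
    attention_level="word",
    extract_objects=True,
)
# prediction["objects"][0] is the list of word polygons of the first image.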
 
 
-def parse_delimiters(delimiters):
-    return re.compile(r"|".join(delimiters))
-
-
-def compute_prob_by_separator(characters, probabilities, separator):
-    """
-    Split text and confidences using separators and return a list of average confidence scores.
-    :param characters: list of characters.
-    :param probabilities: list of probabilities.
-    :param separators: regex for separators. Use parse_delimiters(["\n", " "]) for word confidences and parse_delimiters(["\n"]) for line confidences.
-    Returns a list confidence scores.
-    """
-    # match anything except separators, get start and end index
-    pattern = re.compile(f"[^{separator.pattern}]+")
-    matches = [(m.start(), m.end()) for m in re.finditer(pattern, characters)]
-
-    # Iterate over text pieces and compute mean confidence
-    return [np.mean(probabilities[start:end]) for (start, end) in matches]
-
-
 def run(
     image,
     model,
@@ -255,6 +254,22 @@ def run(
     line_separators,
     predict_objects,
 ):
+    """
+    Predict a single image and save the output.
+    :param image: Path to the image to predict.
+    :param model: Path to the model to use for prediction.
+    :param parameters: Path to the YAML parameters file.
+    :param charset: Path to the charset.
+    :param output: Path to the output folder where the results will be saved.
+    :param scale: Scaling factor to resize the image.
+    :param confidence_score: Whether to compute confidence score.
+    :param attention_map: Whether to plot the attention map.
+    :param attention_map_level: Level of the attention maps (must be in [char, word, line]).
+    :param attention_map_scale: Scaling factor for the attention map.
+    :param word_separators: List of word separators.
+    :param line_separators: List of line separators.
+    :param predict_objects: Whether to extract objects.
+    """
     # Create output directory if necessary
     if not os.path.exists(output):
         os.mkdir(output)
@@ -286,30 +301,36 @@ def run(
         confidences=confidence_score,
         attentions=attention_map,
         attention_level=attention_map_level,
+        extract_objects=predict_objects,
         word_separators=word_separators,
         line_separators=line_separators,
     )
-    text = prediction["text"][0]
-    result = {"text": text}
+    result = {}
+    result["text"] = prediction["text"][0]
 
-    result["objects"] = prediction["objects"]
+    # Return extracted objects (coordinates, text, confidence)
+    if predict_objects:
+        result["objects"] = prediction["objects"][0]
 
-    # Average character-based confidence scores
+    # Return mean confidence score
     if confidence_score:
+        result["confidences"] = {}
+
         char_confidences = prediction["confidences"][0]
-        result["confidences"] = {"total": np.around(np.mean(char_confidences), 2)}
-        if "word" in confidence_score_levels:
-            word_probs = compute_prob_by_separator(
-                text, char_confidences, word_separators
-            )
-            result["confidences"].update({"word": round_floats(word_probs)})
-        if "line" in confidence_score_levels:
-            line_probs = compute_prob_by_separator(
-                text, char_confidences, line_separators
+        result["confidences"]["total"] = np.around(np.mean(char_confidences), 2)
+
+        for level in confidence_score_levels:
+            result["confidences"][level] = []
+            texts, confidences, _ = split_text_and_confidences(
+                prediction["text"][0],
+                char_confidences,
+                level,
+                word_separators,
+                line_separators,
             )
-            result["confidences"].update({"line": round_floats(line_probs)})
-        if "char" in confidence_score_levels:
-            result["confidences"].update({"char": round_floats(char_confidences)})
+
+            for text, conf in zip(texts, confidences):
+                result["confidences"][level].append({"text": text, "confidence": conf})
 
     # Save gif with attention map
     if attention_map:
@@ -324,7 +345,7 @@ def run(
             scale=attention_map_scale,
             word_separators=word_separators,
             line_separators=line_separators,
-            output_polygons=predict_objects,
+            display_polygons=predict_objects,
             outname=gif_filename,
         )
         result["attention_gif"] = gif_filename
-- 
GitLab