From 774887b1a35558639ddb441be7365f19e18896de Mon Sep 17 00:00:00 2001
From: Yoann Schneider <yschneider@teklia.com>
Date: Thu, 20 Jul 2023 11:49:48 +0000
Subject: [PATCH] Batch predictions

---
 dan/predict/__init__.py   |   8 ++
 dan/predict/attention.py  |  14 ++--
 dan/predict/prediction.py | 161 ++++++++++++++++++++----------------
 dan/utils.py              |  12 +++
 tests/test_prediction.py  | 170 +++++++++++++++++++++++++++++++++++++-
 5 files changed, 282 insertions(+), 83 deletions(-)

diff --git a/dan/predict/__init__.py b/dan/predict/__init__.py
index 1b73e3b8..2929f4db 100644
--- a/dan/predict/__init__.py
+++ b/dan/predict/__init__.py
@@ -140,4 +140,12 @@ def add_predict_parser(subcommands) -> None:
         type=int,
         required=False,
     )
+    parser.add_argument(
+        "--batch-size",
+        help="Size of prediction batches.",
+        type=int,
+        default=1,
+        required=False,
+    )
+
     parser.set_defaults(func=run)
diff --git a/dan/predict/attention.py b/dan/predict/attention.py
index 0b06425f..dc8c9397 100644
--- a/dan/predict/attention.py
+++ b/dan/predict/attention.py
@@ -4,6 +4,7 @@ import re
 import cv2
 import numpy as np
 from PIL import Image
+from torchvision.transforms.functional import to_pil_image
 
 from dan import logger
 
@@ -179,7 +180,7 @@ def blend_coverage(coverage_vector, image, mask, scale):
     blend = Image.composite(image, coverage_vector, mask)
 
     # Resize to save time
-    blend = blend.resize((int(width * scale), int(height * scale)), Image.ANTIALIAS)
+    blend = blend.resize((int(width * scale), int(height * scale)), Image.LANCZOS)
     return blend
 
 
@@ -292,7 +293,7 @@ def plot_attention(
 ):
     """
     Create a gif by blending attention maps to the image for each text piece (char, word or line)
-    :param image: Input image in PIL format
+    :param image: Input image as torch.Tensor
     :param text: Text predicted by DAN
     :param weights: Attention weights of size (n_char, feature_height, feature_width)
     :param level: Level to display (must be in [char, word, line])
@@ -303,12 +304,11 @@ def plot_attention(
     :param display_polygons: Whether to plot extracted polygons
     """
 
-    height, width, _ = image.shape
+    image = to_pil_image(image)
     attention_map = []
 
     # Convert to PIL Image and create mask
-    mask = Image.new("L", (width, height), color=(110))
-    image = Image.fromarray(image)
+    mask = Image.new("L", (image.width, image.height), color=(110))
 
     # Split text into characters, words or lines
     text_list, offset = split_text(text, level, word_separators, line_separators)
@@ -320,7 +320,7 @@ def plot_attention(
     for text_piece in text_list:
         # Accumulate weights for the current word/line and resize to original image size
         coverage_vector = compute_coverage(
-            text_piece, max_value, tot_len, weights, (width, height)
+            text_piece, max_value, tot_len, weights, (image.width, image.height)
         )
 
         # Get polygons if flag is set:
@@ -333,7 +333,7 @@ def plot_attention(
                 weights,
                 threshold_method=threshold_method,
                 threshold_value=threshold_value,
-                size=(width, height),
+                size=(image.width, image.height),
             )
 
             if contour is not None:
diff --git a/dan/predict/prediction.py b/dan/predict/prediction.py
index ff87ca3e..3d144d82 100644
--- a/dan/predict/prediction.py
+++ b/dan/predict/prediction.py
@@ -20,7 +20,7 @@ from dan.predict.attention import (
     split_text_and_confidences,
 )
 from dan.transforms import get_normalization_transforms, get_preprocessing_transforms
-from dan.utils import ind_to_token, read_image
+from dan.utils import ind_to_token, list_to_batches, pad_images, read_image
 
 
 class DAN:
@@ -248,8 +248,8 @@ class DAN:
         return out
 
 
-def process_image(
-    image_path,
+def process_batch(
+    image_batch,
     dan_model,
     device,
     output,
@@ -264,20 +264,25 @@ def process_image(
     threshold_method,
     threshold_value,
 ):
-    # Load image and pre-process it
-    image = dan_model.preprocess(str(image_path))
-    logger.info("Image loaded.")
+    input_images, input_sizes = [], []
+    logger.info("Loading images...")
+    for image_path in image_batch:
+        # Load image and pre-process it
+        image = dan_model.preprocess(str(image_path))
+        input_images.append(image)
+        input_sizes.append(image.shape[1:])
 
     # Convert to tensor of size (batch_size, channel, height, width) with batch_size=1
-    input_tensor = image.unsqueeze(0)
-    input_tensor = input_tensor.to(device)
-    input_sizes = [image.shape[1:]]
+    input_tensor = pad_images(input_images).to(device)
+
+    logger.info("Images preprocessed!")
 
     # Parse delimiters to regex
     word_separators = parse_delimiters(word_separators)
     line_separators = parse_delimiters(line_separators)
 
     # Predict
+    logger.info("Predicting...")
     prediction = dan_model.predict(
         input_tensor,
         input_sizes,
@@ -290,70 +295,78 @@ def process_image(
         threshold_method=threshold_method,
         threshold_value=threshold_value,
     )
+    logger.info("Prediction parsing...")
+
+    for idx, image_path in enumerate(image_batch):
+        predicted_text = prediction["text"][idx]
+        result = {"text": predicted_text}
+
+        # Return extracted objects (coordinates, text, confidence)
+        if predict_objects:
+            result["objects"] = prediction["objects"][idx]
+
+        # Return mean confidence score
+        if confidence_score:
+            result["confidences"] = {}
+            char_confidences = prediction["confidences"][idx]
+            # retrieve the index of the token ner
+            index = [
+                pos
+                for pos, char in enumerate(predicted_text)
+                if char in ["ⓝ", "ⓟ", "ⓓ", "ⓡ"]
+            ]
 
-    result = {}
-    result["text"] = prediction["text"][0]
-
-    # Return extracted objects (coordinates, text, confidence)
-    if predict_objects:
-        result["objects"] = prediction["objects"][0]
-
-    # Return mean confidence score
-    if confidence_score:
-        result["confidences"] = {}
-        char_confidences = prediction["confidences"][0]
-        text = result["text"]
-        # retrieve the index of the token ner
-        index = [pos for pos, char in enumerate(text) if char in ["ⓝ", "ⓟ", "ⓓ", "ⓡ"]]
-
-        # calculates scores by token
-
-        result["confidences"]["by ner token"] = [
-            {
-                "text": f"{text[current: next_token]}".replace("\n", " "),
-                "confidence_ner": f"{np.around(np.mean(char_confidences[current : next_token]), 2)}",
-            }
-            # We go up to -1 so that the last token matches until the end of the text
-            for current, next_token in pairwise(index + [-1])
-        ]
-        result["confidences"]["total"] = np.around(np.mean(char_confidences), 2)
-
-        for level in confidence_score_levels:
-            result["confidences"][level] = []
-            texts, confidences, _ = split_text_and_confidences(
-                prediction["text"][0],
-                char_confidences,
-                level,
-                word_separators,
-                line_separators,
-            )
+            # calculates scores by token
 
-            for text, conf in zip(texts, confidences):
-                result["confidences"][level].append({"text": text, "confidence": conf})
-
-    # Save gif with attention map
-    if attention_map:
-        gif_filename = f"{output}/{image_path.stem}_{attention_map_level}.gif"
-        logger.info(f"Creating attention GIF in {gif_filename}")
-        # this returns polygons but unused for now.
-        plot_attention(
-            image=image,
-            text=prediction["text"][0],
-            weights=prediction["attentions"][0],
-            level=attention_map_level,
-            scale=attention_map_scale,
-            word_separators=word_separators,
-            line_separators=line_separators,
-            display_polygons=predict_objects,
-            threshold_method=threshold_method,
-            threshold_value=threshold_value,
-            outname=gif_filename,
-        )
-        result["attention_gif"] = gif_filename
+            result["confidences"]["by ner token"] = [
+                {
+                    "text": f"{predicted_text[current: next_token]}".replace("\n", " "),
+                    "confidence_ner": f"{np.around(np.mean(char_confidences[current : next_token]), 2)}",
+                }
+                # We go up to -1 so that the last token matches until the end of the text
+                for current, next_token in pairwise(index + [-1])
+            ]
+            result["confidences"]["total"] = np.around(np.mean(char_confidences), 2)
+
+            for level in confidence_score_levels:
+                result["confidences"][level] = []
+                texts, confidences, _ = split_text_and_confidences(
+                    predicted_text,
+                    char_confidences,
+                    level,
+                    word_separators,
+                    line_separators,
+                )
+
+                for text, conf in zip(texts, confidences):
+                    result["confidences"][level].append(
+                        {"text": text, "confidence": conf}
+                    )
+
+        # Save gif with attention map
+        if attention_map:
+            attentions = prediction["attentions"][idx]
+            gif_filename = f"{output}/{image_path.stem}_{attention_map_level}.gif"
+            logger.info(f"Creating attention GIF in {gif_filename}")
+            # this returns polygons but unused for now.
+            plot_attention(
+                image=input_tensor[idx],
+                text=predicted_text,
+                weights=attentions,
+                level=attention_map_level,
+                scale=attention_map_scale,
+                word_separators=word_separators,
+                line_separators=line_separators,
+                display_polygons=predict_objects,
+                threshold_method=threshold_method,
+                threshold_value=threshold_value,
+                outname=gif_filename,
+            )
+            result["attention_gif"] = gif_filename
 
-    json_filename = f"{output}/{image_path.stem}.json"
-    logger.info(f"Saving JSON prediction in {json_filename}")
-    save_json(Path(json_filename), result)
+        json_filename = f"{output}/{image_path.stem}.json"
+        logger.info(f"Saving JSON prediction in {json_filename}")
+        save_json(Path(json_filename), result)
 
 
 def run(
@@ -376,6 +389,7 @@ def run(
     threshold_value,
     image_extension,
     gpu_device,
+    batch_size,
 ):
     """
     Predict a single image save the output
@@ -395,6 +409,7 @@ def run(
     :param threshold_method: Thresholding method. Should be in ["otsu", "simple"].
     :param threshold_value: Thresholding value to use for the "simple" thresholding method.
     :param gpu_device: Use a specific GPU if available.
+    :param batch_size: Size of the batches for prediction.
     """
     # Create output directory if necessary
     if not os.path.exists(output):
@@ -407,9 +422,9 @@ def run(
     dan_model.load(model, parameters, charset, mode="eval")
 
     images = image_dir.rglob(f"*{image_extension}") if not image else [image]
-    for image_name in images:
-        process_image(
-            image_name,
+    for image_batch in list_to_batches(images, n=batch_size):
+        process_batch(
+            image_batch,
             dan_model,
             device,
             output,
diff --git a/dan/utils.py b/dan/utils.py
index 1ec9a8cf..9eddb9c5 100644
--- a/dan/utils.py
+++ b/dan/utils.py
@@ -1,4 +1,6 @@
 # -*- coding: utf-8 -*-
+from itertools import islice
+
 import torch
 import torchvision.io as torchvision
 
@@ -72,3 +74,13 @@ def ind_to_token(labels, ind, oov_symbol=None):
     else:
         res = [labels[i] for i in ind]
     return "".join(res)
+
+
+def list_to_batches(iterable, n):
+    "Batch data into tuples of length n. The last batch may be shorter."
+    # list_to_batches('ABCDEFG', 3) --> ABC DEF G
+    if n < 1:
+        raise ValueError("n must be at least one")
+    it = iter(iterable)
+    while batch := tuple(islice(it, n)):
+        yield batch
diff --git a/tests/test_prediction.py b/tests/test_prediction.py
index db87cdb4..9d7409c4 100644
--- a/tests/test_prediction.py
+++ b/tests/test_prediction.py
@@ -1,6 +1,7 @@
 # -*- coding: utf-8 -*-
 
 import json
+import shutil
 
 import pytest
 
@@ -270,9 +271,172 @@ def test_run_prediction(
         threshold_value=0,
         image_extension=None,
         gpu_device=None,
+        batch_size=1,
     )
 
-    with (tmp_path / image_name).with_suffix(".json").open("r") as json_file:
-        prediction = json.load(json_file)
-
+    prediction = json.loads((tmp_path / image_name).with_suffix(".json").read_text())
     assert prediction == expected_prediction
+
+
+@pytest.mark.parametrize(
+    "image_names, confidence_score, temperature, expected_predictions",
+    (
+        (
+            ["0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84"],
+            None,
+            1.0,
+            [{"text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241"}],
+        ),
+        (
+            ["0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84"],
+            ["word"],
+            1.0,
+            [
+                {
+                    "text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",
+                    "confidences": {
+                        "by ner token": [],
+                        "total": 1.0,
+                        "word": [
+                            {"text": "ⓈBellisson", "confidence": 1.0},
+                            {"text": "â’»Georges", "confidence": 1.0},
+                            {"text": "â’·91", "confidence": 1.0},
+                            {"text": "ⓁP", "confidence": 1.0},
+                            {"text": "â’¸M", "confidence": 1.0},
+                            {"text": "â“€Ch", "confidence": 1.0},
+                            {"text": "â“„Plombier", "confidence": 1.0},
+                            {"text": "â“…Patron?12241", "confidence": 1.0},
+                        ],
+                    },
+                }
+            ],
+        ),
+        (
+            [
+                "0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84",
+                "0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84",
+            ],
+            ["word"],
+            1.0,
+            [
+                {
+                    "text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",
+                    "confidences": {
+                        "by ner token": [],
+                        "total": 1.0,
+                        "word": [
+                            {"text": "ⓈBellisson", "confidence": 1.0},
+                            {"text": "â’»Georges", "confidence": 1.0},
+                            {"text": "â’·91", "confidence": 1.0},
+                            {"text": "ⓁP", "confidence": 1.0},
+                            {"text": "â’¸M", "confidence": 1.0},
+                            {"text": "â“€Ch", "confidence": 1.0},
+                            {"text": "â“„Plombier", "confidence": 1.0},
+                            {"text": "â“…Patron?12241", "confidence": 1.0},
+                        ],
+                    },
+                },
+                {
+                    "text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",
+                    "confidences": {
+                        "by ner token": [],
+                        "total": 1.0,
+                        "word": [
+                            {"text": "ⓈBellisson", "confidence": 1.0},
+                            {"text": "â’»Georges", "confidence": 1.0},
+                            {"text": "â’·91", "confidence": 1.0},
+                            {"text": "ⓁP", "confidence": 1.0},
+                            {"text": "â’¸M", "confidence": 1.0},
+                            {"text": "â“€Ch", "confidence": 1.0},
+                            {"text": "â“„Plombier", "confidence": 1.0},
+                            {"text": "â“…Patron?12241", "confidence": 1.0},
+                        ],
+                    },
+                },
+            ],
+        ),
+        (
+            ["0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84"],
+            ["word"],
+            1.0,
+            [
+                {
+                    "text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",
+                    "confidences": {
+                        "by ner token": [],
+                        "total": 1.0,
+                        "word": [
+                            {"text": "ⓈBellisson", "confidence": 1.0},
+                            {"text": "â’»Georges", "confidence": 1.0},
+                            {"text": "â’·91", "confidence": 1.0},
+                            {"text": "ⓁP", "confidence": 1.0},
+                            {"text": "â’¸M", "confidence": 1.0},
+                            {"text": "â“€Ch", "confidence": 1.0},
+                            {"text": "â“„Plombier", "confidence": 1.0},
+                            {"text": "â“…Patron?12241", "confidence": 1.0},
+                        ],
+                    },
+                }
+            ],
+        ),
+        (
+            [
+                "2c242f5c-e979-43c4-b6f2-a6d4815b651d",
+                "ffdec445-7f14-4f5f-be44-68d0844d0df1",
+            ],
+            False,
+            1.0,
+            [
+                {"text": "Ⓢd ⒻCharles Ⓑ11 ⓁP ⒸC ⓀF Ⓞd Ⓟ14 31"},
+                {"text": "ⓈNaudin ⒻMarie Ⓑ53 ⓁS Ⓒv ⓀBelle mère"},
+            ],
+        ),
+    ),
+)
+@pytest.mark.parametrize("batch_size", [1, 2])
+def test_run_prediction_batch(
+    image_names,
+    confidence_score,
+    temperature,
+    expected_predictions,
+    prediction_data_path,
+    batch_size,
+    tmp_path,
+):
+    # Make tmpdir and copy needed images inside
+    image_dir = tmp_path / "images"
+    image_dir.mkdir()
+    for image_name in image_names:
+        shutil.copyfile(
+            (prediction_data_path / "images" / image_name).with_suffix(".png"),
+            (image_dir / image_name).with_suffix(".png"),
+        )
+
+    run_prediction(
+        image=None,
+        image_dir=image_dir,
+        model=prediction_data_path / "popp_line_model.pt",
+        parameters=prediction_data_path / "parameters.yml",
+        charset=prediction_data_path / "charset.pkl",
+        output=tmp_path,
+        confidence_score=True if confidence_score else False,
+        confidence_score_levels=confidence_score if confidence_score else [],
+        attention_map=False,
+        attention_map_level=None,
+        attention_map_scale=0.5,
+        word_separators=[" ", "\n"],
+        line_separators=["\n"],
+        temperature=temperature,
+        predict_objects=False,
+        threshold_method="otsu",
+        threshold_value=0,
+        image_extension=".png",
+        gpu_device=None,
+        batch_size=batch_size,
+    )
+
+    for image_name, expected_prediction in zip(image_names, expected_predictions):
+        prediction = json.loads(
+            (tmp_path / image_name).with_suffix(".json").read_text()
+        )
+        assert prediction == expected_prediction
-- 
GitLab