From bac1de3c8e87c1354ae6fe85c6ea0b6831158692 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?M=C3=A9lodie=20Boillet?= <boillet@teklia.com>
Date: Thu, 3 Aug 2023 12:39:01 +0200
Subject: [PATCH] Apply 1956acc9

---
 dan/predict/prediction.py | 17 ++++++++++-------
 tests/test_prediction.py  |  2 +-
 2 files changed, 11 insertions(+), 8 deletions(-)

diff --git a/dan/predict/prediction.py b/dan/predict/prediction.py
index 9104aa8f..f2e2ee16 100644
--- a/dan/predict/prediction.py
+++ b/dan/predict/prediction.py
@@ -90,11 +90,12 @@ class DAN:
         """
         image = read_image(path)
         preprocessed_image = self.preprocessing_transforms(image)
+        normalized_image = torch.zeros(preprocessed_image.shape)
         for ch in range(preprocessed_image.shape[0]):
-            preprocessed_image[ch, :, :] = (
+            normalized_image[ch, :, :] = (
                 preprocessed_image[ch, :, :] - self.mean[ch]
             ) / self.std[ch]
-        return preprocessed_image
+        return preprocessed_image, normalized_image
 
     def predict(
         self,
@@ -271,16 +272,18 @@ def process_batch(
     threshold_method,
     threshold_value,
 ):
-    input_images, input_sizes = [], []
+    input_images, visu_images, input_sizes = [], [], []
     logger.info("Loading images...")
     for image_path in image_batch:
         # Load image and pre-process it
-        image = dan_model.preprocess(str(image_path))
-        input_images.append(image)
-        input_sizes.append(image.shape[1:])
+        visu_image, input_image = dan_model.preprocess(str(image_path))
+        input_images.append(input_image)
+        visu_images.append(visu_image)
+        input_sizes.append(input_image.shape[1:])
 
     # Convert to tensor of size (batch_size, channel, height, width) with batch_size=1
     input_tensor = pad_images(input_images).to(device)
+    visu_tensor = pad_images(visu_images).to(device)
     logger.info("Images preprocessed!")
 
     # Parse delimiters to regex
@@ -355,7 +358,7 @@ def process_batch(
             logger.info(f"Creating attention GIF in {gif_filename}")
             # this returns polygons but unused for now.
             plot_attention(
-                image=input_tensor[idx],
+                image=visu_tensor[idx],
                 text=predicted_text,
                 weights=attentions,
                 level=attention_map_level,
diff --git a/tests/test_prediction.py b/tests/test_prediction.py
index fad5d04d..58972b1b 100644
--- a/tests/test_prediction.py
+++ b/tests/test_prediction.py
@@ -46,7 +46,7 @@ def test_predict(
     )
 
     image_path = prediction_data_path / "images" / image_name
-    image = dan_model.preprocess(str(image_path))
+    _, image = dan_model.preprocess(str(image_path))
 
     input_tensor = image.unsqueeze(0)
     input_tensor = input_tensor.to(device)
-- 
GitLab