From f5e6b7849a3645c533c303a32d33f6781332409b Mon Sep 17 00:00:00 2001
From: Tristan Faine <tfaine@teklia.com>
Date: Tue, 7 Mar 2023 17:05:35 +0100
Subject: [PATCH] Refactor attention map aggregation in compute_coverage

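Read the feature map height and width explicitly from the attention
tensor's shape, and split the per-character lookup and resize into two
steps for readability. Behaviour is unchanged.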
---
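Note for reviewers: a minimal sketch of how compute_coverage is driven,
assuming an attention tensor of shape (n_char, feature_height,
feature_width); the input values and the use of attentions.max() as
max_value are illustrative assumptions, not taken from the repository:

    import numpy as np

    from dan.predict.attention import compute_coverage

    # Fake attention weights: 11 characters over an 8x16 feature map
    attentions = np.random.rand(11, 8, 16).astype(np.float32)

    # Coverage of the whole text, reading weights from offset 0 onwards
    coverage = compute_coverage(
        "hello world", max_value=attentions.max(), offset=0, attentions=attentions
    )
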
 dan/predict/attention.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/dan/predict/attention.py b/dan/predict/attention.py
index 3d32b5c8..09229ec0 100644
--- a/dan/predict/attention.py
+++ b/dan/predict/attention.py
@@ -39,12 +39,15 @@ def compute_coverage(text: str, max_value: float, offset: int, attentions):
     :param offset: Offset value to get the relevant part of text piece
     :param attentions: Attention weights of size (n_char, feature_height, feature_width)
     """
-    _, height, width = attentions.shape
+    height = attentions.shape[1]
+    width = attentions.shape[2]
 
     # blank vector to accumulate weights for the current text
     coverage_vector = np.zeros((height, width))
     for i in range(len(text)):
-        local_weight = cv2.resize(attentions[i + offset], (width, height))
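+        # Fetch this character's attention map, then resize it to match the coverage vector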
+        local_weight = attentions[i + offset]
+        local_weight = cv2.resize(local_weight, (width, height))
         coverage_vector = np.clip(coverage_vector + local_weight, 0, 1)
 
     # Normalize coverage vector
-- 
GitLab