From f5e6b7849a3645c533c303a32d33f6781332409b Mon Sep 17 00:00:00 2001 From: Tristan Faine <tfaine@teklia.com> Date: Tue, 7 Mar 2023 17:05:35 +0100 Subject: [PATCH] Added function to aggregate attention maps --- dan/predict/attention.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/dan/predict/attention.py b/dan/predict/attention.py index 3d32b5c8..09229ec0 100644 --- a/dan/predict/attention.py +++ b/dan/predict/attention.py @@ -3,6 +3,7 @@ import re import cv2 import numpy as np +import math from PIL import Image from dan import logger @@ -39,12 +40,14 @@ def compute_coverage(text: str, max_value: float, offset: int, attentions): :param offset: Offset value to get the relevant part of text piece :param attentions: Attention weights of size (n_char, feature_height, feature_width) """ - _, height, width = attentions.shape + height = attentions.shape[1] + width = attentions.shape[2] # blank vector to accumulate weights for the current text coverage_vector = np.zeros((height, width)) for i in range(len(text)): - local_weight = cv2.resize(attentions[i + offset], (width, height)) + local_weight = attentions[i + offset] + local_weight = cv2.resize(local_weight, (width, height)) coverage_vector = np.clip(coverage_vector + local_weight, 0, 1) # Normalize coverage vector @@ -71,8 +74,8 @@ def plot_attention( :param scale: Scaling factor for the output gif image :param outname: Name of the gif image """ - - height, width, _ = image.shape + + height, width, _ = image.shape attention_map = [] # Convert to PIL Image and create mask -- GitLab