Skip to content
Snippets Groups Projects
Commit f5e6b784 authored by Tristan Faine's avatar Tristan Faine Committed by Solene Tarride
Browse files

Added function to aggregate attention maps

parent 614fa206
No related branches found
No related tags found
1 merge request!76Add predicted objects to predict command
......@@ -3,6 +3,7 @@ import re
import cv2
import numpy as np
import math
from PIL import Image
from dan import logger
......@@ -39,12 +40,14 @@ def compute_coverage(text: str, max_value: float, offset: int, attentions):
:param offset: Offset value to get the relevant part of text piece
:param attentions: Attention weights of size (n_char, feature_height, feature_width)
"""
_, height, width = attentions.shape
height = attentions.shape[1]
width = attentions.shape[2]
# blank vector to accumulate weights for the current text
coverage_vector = np.zeros((height, width))
for i in range(len(text)):
local_weight = cv2.resize(attentions[i + offset], (width, height))
local_weight = attentions[i + offset]
local_weight = cv2.resize(local_weight, (width, height))
coverage_vector = np.clip(coverage_vector + local_weight, 0, 1)
# Normalize coverage vector
......@@ -71,8 +74,8 @@ def plot_attention(
:param scale: Scaling factor for the output gif image
:param outname: Name of the gif image
"""
height, width, _ = image.shape
height, width, _ = image.shape
attention_map = []
# Convert to PIL Image and create mask
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment