Skip to content
Snippets Groups Projects

Added function to aggregate attention maps

Merged Thibault Lavigne requested to merge 35-new-functon-to-aggregate-attention-maps-2 into main
All threads resolved!
1 file
+ 26
9
Compare changes
  • Side-by-side
  • Inline
+ 26
9
@@ -31,6 +31,27 @@ def split_text(text, level, word_separators, line_separators):
return text_split, offset
def compute_coverage(text: str, max_value: float, offset: int, attentions):
    """
    Aggregates attention maps for the current text piece (char, word, line)
    :param text: Text piece selected with offset after splitting DAN prediction
    :param max_value: Maximum "attention intensity" for parts of a text piece, used for normalization
    :param offset: Offset value to get the relevant part of text piece
    :param attentions: Attention weights of size (n_char, feature_height, feature_width)
    :return: Aggregated attention map of size (feature_height, feature_width), uint8 in [0, 255]
    """
    _, height, width = attentions.shape
    # Blank vector to accumulate weights for the current text
    coverage_vector = np.zeros((height, width))
    # Accumulate one attention map per character of the text piece, clipping
    # after each addition so no region exceeds full intensity.
    # NOTE: the previous version called cv2.resize(attentions[i + offset],
    # (width, height)) — resizing each map to its own shape, a no-op; the
    # redundant resize is removed (the caller resizes the result to image size).
    for local_weight in attentions[offset : offset + len(text)]:
        coverage_vector = np.clip(coverage_vector + local_weight, 0, 1)
    # Normalize coverage vector into the displayable uint8 range
    coverage_vector = (coverage_vector / max_value * 255).astype(np.uint8)
    return coverage_vector
def plot_attention(
image,
text,
@@ -50,6 +71,7 @@ def plot_attention(
:param scale: Scaling factor for the output gif image
:param outname: Name of the gif image
"""
height, width, _ = image.shape
attention_map = []
@@ -64,20 +86,15 @@ def plot_attention(
tot_len = 0
max_value = weights.sum(0).max()
for text_piece in text_list:
# blank vector to accumulate weights for the current word/line
coverage_vector = np.zeros((height, width))
for i in range(len(text_piece)):
local_weight = weights[i + tot_len]
local_weight = cv2.resize(local_weight, (width, height))
coverage_vector = np.clip(coverage_vector + local_weight, 0, 1)
# Accumulate weights for the current word/line and resize to original image size
coverage_vector = compute_coverage(text_piece, max_value, tot_len, weights)
coverage_vector = cv2.resize(coverage_vector, (width, height))
# Keep track of text length
tot_len += len(text_piece) + offset
# Normalize coverage vector
coverage_vector = (coverage_vector / max_value * 255).astype(np.uint8)
# Blend coverage vector with original image
blank_array = np.zeros((height, width)).astype(np.uint8)
coverage_vector = Image.fromarray(
Loading