Skip to content
Snippets Groups Projects
Commit fef52f73 authored by Tristan Faine's avatar Tristan Faine Committed by Yoann Schneider
Browse files

Added function to aggregate attention maps

parent 6c4774d5
No related branches found
No related tags found
1 merge request!75Added function to aggregate attention maps
......@@ -31,6 +31,27 @@ def split_text(text, level, word_separators, line_separators):
return text_split, offset
def compute_coverage(text: str, max_value: float, offset: int, attentions):
    """
    Aggregates attention maps for the current text piece (char, word, line)
    into a single 8-bit intensity map.

    :param text: Text piece selected with offset after splitting DAN prediction
    :param max_value: Maximum "attention intensity" for parts of a text piece, used for normalization
    :param offset: Index of the first attention map belonging to this text piece
    :param attentions: Attention weights of size (n_char, feature_height, feature_width)
    :return: Array of shape (feature_height, feature_width) and dtype uint8,
        holding the accumulated attention scaled to the [0, 255] range
    """
    _, height, width = attentions.shape

    # blank vector to accumulate weights for the current text
    coverage_vector = np.zeros((height, width))
    # Each map already has the (height, width) feature shape, so it can be
    # accumulated directly; resizing to the original image is left to the caller.
    for index in range(offset, offset + len(text)):
        # Clip after every step so the running total stays within [0, 1]
        coverage_vector = np.clip(coverage_vector + attentions[index], 0, 1)

    # Nothing to normalize against: return an empty map instead of dividing by zero
    if max_value == 0:
        return np.zeros((height, width), dtype=np.uint8)

    # Normalize coverage vector, saturating rather than wrapping around in uint8
    scaled = np.clip(coverage_vector / max_value * 255, 0, 255)
    return scaled.astype(np.uint8)
def plot_attention(
image,
text,
......@@ -50,6 +71,7 @@ def plot_attention(
:param scale: Scaling factor for the output gif image
:param outname: Name of the gif image
"""
height, width, _ = image.shape
attention_map = []
......@@ -64,20 +86,15 @@ def plot_attention(
tot_len = 0
max_value = weights.sum(0).max()
for text_piece in text_list:
# blank vector to accumulate weights for the current word/line
coverage_vector = np.zeros((height, width))
for i in range(len(text_piece)):
local_weight = weights[i + tot_len]
local_weight = cv2.resize(local_weight, (width, height))
coverage_vector = np.clip(coverage_vector + local_weight, 0, 1)
# Accumulate weights for the current word/line and resize to original image size
coverage_vector = compute_coverage(text_piece, max_value, tot_len, weights)
coverage_vector = cv2.resize(coverage_vector, (width, height))
# Keep track of text length
tot_len += len(text_piece) + offset
# Normalize coverage vector
coverage_vector = (coverage_vector / max_value * 255).astype(np.uint8)
# Blend coverage vector with original image
blank_array = np.zeros((height, width)).astype(np.uint8)
coverage_vector = Image.fromarray(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment