Skip to content
Snippets Groups Projects

Do not normalize tensor for attention map visualization

Merged Solene Tarride requested to merge do-not-normalize-tensor-attention into main
All threads resolved!
2 files
+ 9
7
Compare changes
  • Side-by-side
  • Inline
Files
2
+ 8
6
@@ -87,7 +87,7 @@ class DAN:
"""
image = read_image(path)
preprocessed_image = self.preprocessing_transforms(image)
return self.normalization(preprocessed_image)
return preprocessed_image, self.normalization(preprocessed_image)
def predict(
self,
@@ -264,16 +264,18 @@ def process_batch(
threshold_method,
threshold_value,
):
input_images, input_sizes = [], []
input_images, visu_images, input_sizes = [], [], []
logger.info("Loading images...")
for image_path in image_batch:
# Load image and pre-process it
image = dan_model.preprocess(str(image_path))
input_images.append(image)
input_sizes.append(image.shape[1:])
visu_image, input_image = dan_model.preprocess(str(image_path))
input_images.append(input_image)
visu_images.append(visu_image)
input_sizes.append(input_image.shape[1:])
# Convert to tensor of size (batch_size, channel, height, width) with batch_size=1
input_tensor = pad_images(input_images).to(device)
visu_tensor = pad_images(visu_images).to(device)
logger.info("Images preprocessed!")
@@ -350,7 +352,7 @@ def process_batch(
logger.info(f"Creating attention GIF in {gif_filename}")
# This returns polygons, but they are unused for now.
plot_attention(
image=input_tensor[idx],
image=visu_tensor[idx],
text=predicted_text,
weights=attentions,
level=attention_map_level,
Loading