Skip to content
Snippets Groups Projects
Commit 1956acc9 authored by Solene Tarride's avatar Solene Tarride Committed by Yoann Schneider
Browse files

Do not normalize tensor for attention map visualization

parent 30613c99
Branches
Tags
1 merge request!221Do not normalize tensor for attention map visualization
......@@ -87,7 +87,7 @@ class DAN:
"""
image = read_image(path)
preprocessed_image = self.preprocessing_transforms(image)
return self.normalization(preprocessed_image)
return preprocessed_image, self.normalization(preprocessed_image)
def predict(
self,
......@@ -264,16 +264,18 @@ def process_batch(
threshold_method,
threshold_value,
):
input_images, input_sizes = [], []
input_images, visu_images, input_sizes = [], [], []
logger.info("Loading images...")
for image_path in image_batch:
# Load image and pre-process it
image = dan_model.preprocess(str(image_path))
input_images.append(image)
input_sizes.append(image.shape[1:])
visu_image, input_image = dan_model.preprocess(str(image_path))
input_images.append(input_image)
visu_images.append(visu_image)
input_sizes.append(input_image.shape[1:])
# Convert to tensor of size (batch_size, channel, height, width) with batch_size=1
input_tensor = pad_images(input_images).to(device)
visu_tensor = pad_images(visu_images).to(device)
logger.info("Images preprocessed!")
......@@ -350,7 +352,7 @@ def process_batch(
logger.info(f"Creating attention GIF in {gif_filename}")
# this returns polygons but unused for now.
plot_attention(
image=input_tensor[idx],
image=visu_tensor[idx],
text=predicted_text,
weights=attentions,
level=attention_map_level,
......
......@@ -46,7 +46,7 @@ def test_predict(
)
image_path = prediction_data_path / "images" / image_name
image = dan_model.preprocess(str(image_path))
_, image = dan_model.preprocess(str(image_path))
input_tensor = image.unsqueeze(0)
input_tensor = input_tensor.to(device)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment