Skip to content
Snippets Groups Projects
Verified Commit bac1de3c authored by Mélodie Boillet's avatar Mélodie Boillet
Browse files

Apply 1956acc9

parent ee994ff3
No related branches found
No related tags found
No related merge requests found
...@@ -90,11 +90,12 @@ class DAN: ...@@ -90,11 +90,12 @@ class DAN:
""" """
image = read_image(path) image = read_image(path)
preprocessed_image = self.preprocessing_transforms(image) preprocessed_image = self.preprocessing_transforms(image)
normalized_image = torch.zeros(preprocessed_image.shape)
for ch in range(preprocessed_image.shape[0]): for ch in range(preprocessed_image.shape[0]):
preprocessed_image[ch, :, :] = ( normalized_image[ch, :, :] = (
preprocessed_image[ch, :, :] - self.mean[ch] preprocessed_image[ch, :, :] - self.mean[ch]
) / self.std[ch] ) / self.std[ch]
return preprocessed_image return preprocessed_image, normalized_image
def predict( def predict(
self, self,
...@@ -271,16 +272,18 @@ def process_batch( ...@@ -271,16 +272,18 @@ def process_batch(
threshold_method, threshold_method,
threshold_value, threshold_value,
): ):
input_images, input_sizes = [], [] input_images, visu_images, input_sizes = [], [], []
logger.info("Loading images...") logger.info("Loading images...")
for image_path in image_batch: for image_path in image_batch:
# Load image and pre-process it # Load image and pre-process it
image = dan_model.preprocess(str(image_path)) visu_image, input_image = dan_model.preprocess(str(image_path))
input_images.append(image) input_images.append(input_image)
input_sizes.append(image.shape[1:]) visu_images.append(visu_image)
input_sizes.append(input_image.shape[1:])
# Convert to tensor of size (batch_size, channel, height, width) with batch_size=1 # Convert to tensor of size (batch_size, channel, height, width) with batch_size=1
input_tensor = pad_images(input_images).to(device) input_tensor = pad_images(input_images).to(device)
visu_tensor = pad_images(visu_images).to(device)
logger.info("Images preprocessed!") logger.info("Images preprocessed!")
# Parse delimiters to regex # Parse delimiters to regex
...@@ -355,7 +358,7 @@ def process_batch( ...@@ -355,7 +358,7 @@ def process_batch(
logger.info(f"Creating attention GIF in {gif_filename}") logger.info(f"Creating attention GIF in {gif_filename}")
# this returns polygons but unused for now. # this returns polygons but unused for now.
plot_attention( plot_attention(
image=input_tensor[idx], image=visu_tensor[idx],
text=predicted_text, text=predicted_text,
weights=attentions, weights=attentions,
level=attention_map_level, level=attention_map_level,
......
...@@ -46,7 +46,7 @@ def test_predict( ...@@ -46,7 +46,7 @@ def test_predict(
) )
image_path = prediction_data_path / "images" / image_name image_path = prediction_data_path / "images" / image_name
image = dan_model.preprocess(str(image_path)) _, image = dan_model.preprocess(str(image_path))
input_tensor = image.unsqueeze(0) input_tensor = image.unsqueeze(0)
input_tensor = input_tensor.to(device) input_tensor = input_tensor.to(device)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment