diff --git a/dan/predict/prediction.py b/dan/predict/prediction.py index 9104aa8f5b9a1a0d7a9f6b499d1cc201dd80fce7..f2e2ee168089ad322c46f9f99db67c09258e1a6e 100644 --- a/dan/predict/prediction.py +++ b/dan/predict/prediction.py @@ -90,11 +90,12 @@ class DAN: """ image = read_image(path) preprocessed_image = self.preprocessing_transforms(image) + normalized_image = torch.zeros(preprocessed_image.shape) for ch in range(preprocessed_image.shape[0]): - preprocessed_image[ch, :, :] = ( + normalized_image[ch, :, :] = ( preprocessed_image[ch, :, :] - self.mean[ch] ) / self.std[ch] - return preprocessed_image + return preprocessed_image, normalized_image def predict( self, @@ -271,16 +272,18 @@ def process_batch( threshold_method, threshold_value, ): - input_images, input_sizes = [], [] + input_images, visu_images, input_sizes = [], [], [] logger.info("Loading images...") for image_path in image_batch: # Load image and pre-process it - image = dan_model.preprocess(str(image_path)) - input_images.append(image) - input_sizes.append(image.shape[1:]) + visu_image, input_image = dan_model.preprocess(str(image_path)) + input_images.append(input_image) + visu_images.append(visu_image) + input_sizes.append(input_image.shape[1:]) # Convert to tensor of size (batch_size, channel, height, width) with batch_size=1 input_tensor = pad_images(input_images).to(device) + visu_tensor = pad_images(visu_images).to(device) logger.info("Images preprocessed!") # Parse delimiters to regex @@ -355,7 +358,7 @@ def process_batch( logger.info(f"Creating attention GIF in {gif_filename}") # this returns polygons but unused for now. plot_attention( - image=input_tensor[idx], + image=visu_tensor[idx], text=predicted_text, weights=attentions, level=attention_map_level, diff --git a/tests/test_prediction.py b/tests/test_prediction.py index fad5d04d3482fa5adecf8d56226f6d7b60bd0c36..58972b1b48d43c0fcaa824f55c58710fdaee3766 100644 --- a/tests/test_prediction.py +++ b/tests/test_prediction.py @@ -46,7 +46,7 @@ def test_predict( ) image_path = prediction_data_path / "images" / image_name - image = dan_model.preprocess(str(image_path)) + _, image = dan_model.preprocess(str(image_path)) input_tensor = image.unsqueeze(0) input_tensor = input_tensor.to(device)