From dfd6835bfa6f119adfca3ecb96645ee2ff230381 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?M=C3=A9lodie=20Boillet?= <boillet@teklia.com> Date: Fri, 7 Jul 2023 11:37:05 +0000 Subject: [PATCH] Fix input tensor used during prediction --- dan/predict/prediction.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/dan/predict/prediction.py b/dan/predict/prediction.py index 3b9fa98c..3726679e 100644 --- a/dan/predict/prediction.py +++ b/dan/predict/prediction.py @@ -282,9 +282,9 @@ def process_image( logger.debug("Image pre-processed.") # Convert to tensor of size (batch_size, channel, height, width) with batch_size=1 - input_tensor = torch.tensor(im_p).permute(2, 0, 1).unsqueeze(0) + input_tensor = im_p.unsqueeze(0) input_tensor = input_tensor.to(device) - input_sizes = [im.shape[:2]] + input_sizes = [im_p.shape[1:]] # Parse delimiters to regex word_separators = parse_delimiters(word_separators) @@ -421,7 +421,7 @@ def run( cuda_device = f":{gpu_device}" if gpu_device is not None else "" device = f"cuda{cuda_device}" if torch.cuda.is_available() else "cpu" dan_model = DAN(device, temperature) - dan_model.load(model, parameters, charset, mode="train") + dan_model.load(model, parameters, charset, mode="eval") if image: process_image( image, -- GitLab