From 78a6947160cc79f526b5185aa28a3f11789ac9d1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?M=C3=A9lodie=20Boillet?= <boillet@teklia.com> Date: Fri, 7 Jul 2023 11:37:05 +0000 Subject: [PATCH] Fix input tensor used during prediction --- dan/predict/prediction.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/dan/predict/prediction.py b/dan/predict/prediction.py index ac029c84..9aa1cd5f 100644 --- a/dan/predict/prediction.py +++ b/dan/predict/prediction.py @@ -283,9 +283,9 @@ def process_image( logger.debug("Image pre-processed.") # Convert to tensor of size (batch_size, channel, height, width) with batch_size=1 - input_tensor = torch.tensor(im_p).permute(2, 0, 1).unsqueeze(0) + input_tensor = im_p.unsqueeze(0) input_tensor = input_tensor.to(device) - input_sizes = [im.shape[:2]] + input_sizes = [im_p.shape[1:]] # Parse delimiters to regex word_separators = parse_delimiters(word_separators) @@ -422,7 +422,7 @@ def run( cuda_device = f":{gpu_device}" if gpu_device is not None else "" device = f"cuda{cuda_device}" if torch.cuda.is_available() else "cpu" dan_model = DAN(device, temperature) - dan_model.load(model, parameters, charset, mode="train") + dan_model.load(model, parameters, charset, mode="eval") if image: process_image( image, -- GitLab