From 78a6947160cc79f526b5185aa28a3f11789ac9d1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?M=C3=A9lodie=20Boillet?= <boillet@teklia.com>
Date: Fri, 7 Jul 2023 11:37:05 +0000
Subject: [PATCH] Fix input tensor used during prediction

---
 dan/predict/prediction.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/dan/predict/prediction.py b/dan/predict/prediction.py
index ac029c84..9aa1cd5f 100644
--- a/dan/predict/prediction.py
+++ b/dan/predict/prediction.py
@@ -283,9 +283,9 @@ def process_image(
     logger.debug("Image pre-processed.")
 
     # Convert to tensor of size (batch_size, channel, height, width) with batch_size=1
-    input_tensor = torch.tensor(im_p).permute(2, 0, 1).unsqueeze(0)
+    input_tensor = im_p.unsqueeze(0)
     input_tensor = input_tensor.to(device)
-    input_sizes = [im.shape[:2]]
+    input_sizes = [im_p.shape[1:]]
 
     # Parse delimiters to regex
     word_separators = parse_delimiters(word_separators)
@@ -422,7 +422,7 @@ def run(
     cuda_device = f":{gpu_device}" if gpu_device is not None else ""
     device = f"cuda{cuda_device}" if torch.cuda.is_available() else "cpu"
     dan_model = DAN(device, temperature)
-    dan_model.load(model, parameters, charset, mode="train")
+    dan_model.load(model, parameters, charset, mode="eval")
     if image:
         process_image(
             image,
-- 
GitLab