From e09a5ec431502e23c226197ac3f2abe12deadc88 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sol=C3=A8ne=20Tarride?= <starride@teklia.com>
Date: Mon, 25 Sep 2023 16:40:29 +0200
Subject: [PATCH] Move tensor to correct device and trim prediction

---
 dan/ocr/decoder.py | 16 +++++++++++++---
 1 file changed, 13 insertions(+), 3 deletions(-)

diff --git a/dan/ocr/decoder.py b/dan/ocr/decoder.py
index 6dfda6d9..8279d205 100644
--- a/dan/ocr/decoder.py
+++ b/dan/ocr/decoder.py
@@ -519,12 +519,22 @@ class CTCLanguageDecoder:
         batch_size, n_frames, n_tokens = batch_features.shape
         # Reset probabilities for the CTC token
         batch_features[:, :, -1] = (
-            torch.ones((batch_size, n_frames), dtype=torch.float32) * low_prob
+            torch.ones(
+                (batch_size, n_frames),
+                dtype=torch.float32,
+                device=batch_features.device,
+            )
+            * low_prob
         )
 
         # Create a frame with high probability CTC token
         ctc_probs = (
-            torch.ones((batch_size, 1, n_tokens), dtype=torch.float32) * low_prob
+            torch.ones(
+                (batch_size, 1, n_tokens),
+                dtype=torch.float32,
+                device=batch_features.device,
+            )
+            * low_prob
         )
         ctc_probs[:, :, self.blank_token_id] = high_prob
         ctc_probs = ctc_probs
@@ -558,7 +568,7 @@ class CTCLanguageDecoder:
                     else self.index_to_token[token]
                     for token in hypothesis[0].tokens.tolist()
                 ]
-            )
+            ).strip()
             for hypothesis in hypotheses
         ]
         # Normalize confidence score
-- 
GitLab