From 347383a84bff0fa65a532a72fded69e1f187adbf Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sol=C3=A8ne=20Tarride?= <starride@teklia.com>
Date: Mon, 25 Sep 2023 15:05:05 +0200
Subject: [PATCH] Use CTCHypothesis.tokens instead of CTCHypothesis.words

---
 dan/ocr/decoder.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/dan/ocr/decoder.py b/dan/ocr/decoder.py
index b9b256d1..6dfda6d9 100644
--- a/dan/ocr/decoder.py
+++ b/dan/ocr/decoder.py
@@ -495,6 +495,7 @@ class CTCLanguageDecoder:
         self.tokens_to_index = {
             token: i for i, token in enumerate(read_txt(tokens_path).split("\n"))
         }
+        self.index_to_token = {i: token for token, i in self.tokens_to_index.items()}
         self.blank_token_id = self.tokens_to_index[self.mapping.ctc.encoded]
         self.decoder = ctc_decoder(
             lm=language_model_path,
@@ -516,7 +517,6 @@ class CTCLanguageDecoder:
         high_prob = batch_features.max()
         low_prob = batch_features.min()
         batch_size, n_frames, n_tokens = batch_features.shape
-
         # Reset probabilities for the CTC token
         batch_features[:, :, -1] = (
             torch.ones((batch_size, n_frames), dtype=torch.float32) * low_prob
@@ -553,10 +553,10 @@ class CTCLanguageDecoder:
         out["text"] = [
             "".join(
                 [
-                    self.mapping.display[token]
-                    if token in self.mapping.display
-                    else token
-                    for token in hypothesis[0].words
+                    self.mapping.display[self.index_to_token[token]]
+                    if self.index_to_token[token] in self.mapping.display
+                    else self.index_to_token[token]
+                    for token in hypothesis[0].tokens.tolist()
                 ]
             )
             for hypothesis in hypotheses
-- 
GitLab