From 7f25f1373480804a85ae8dd9a25b498d85ad7739 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sol=C3=A8ne=20Tarride?= <starride@teklia.com>
Date: Tue, 12 Sep 2023 12:03:34 +0200
Subject: [PATCH] Improve documentation

---
 dan/ocr/decoder.py | 16 +++++++++-------
 1 file changed, 9 insertions(+), 7 deletions(-)

diff --git a/dan/ocr/decoder.py b/dan/ocr/decoder.py
index 6a158382..59909d14 100644
--- a/dan/ocr/decoder.py
+++ b/dan/ocr/decoder.py
@@ -503,10 +503,9 @@ class CTCLanguageDecoder:
             nbest=1,
         )
         self.temperature = temperature
-
-        self.tokens_to_idx = read_txt(tokens_path).split("\n")
-        self.ctc_id = self.tokens_to_idx.index(blank_token)
         self.space_token = sil_token
+        self.tokens_to_idx = read_txt(tokens_path).split("\n")
+        self.blank_id = self.tokens_to_idx.index(blank_token)
 
     def add_ctc_frames(self, batch_features):
         """
@@ -514,16 +513,19 @@ class CTCLanguageDecoder:
         """
         batch_size, n_frames, n_tokens = batch_features.shape
 
-        # column with 1 probability on CTC token
+        # Create a tensor with high probability on the CTC (blank) token
+        high_prob = 0.99
+        low_prob = 0.01
         ctc_probs = (
             torch.ones((batch_size, 1, n_tokens), dtype=torch.float32)
-            * 0.01
+            * low_prob
             / (n_tokens - 1)
         )
-        ctc_probs[:, :, self.ctc_id] = 0.99
+        ctc_probs[:, :, self.blank_id] = high_prob
         ctc_probs = ctc_probs.log()
 
-        for i in range(n_frames - 1):
+        # Insert CTC tensor between frames
+        for i in range(n_frames):
             batch_features = torch.cat(
                 [
                     batch_features[:, : 2 * i + 1, :],
-- 
GitLab