From bffe2c05fbd7d85bf0db34f29aa4ee5c7bb55697 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sol=C3=A8ne=20Tarride?= <starride@teklia.com>
Date: Sat, 23 Sep 2023 09:17:26 +0200
Subject: [PATCH] Fix CTC token probability

---
 dan/ocr/decoder.py | 26 ++++++++++++++++----------
 1 file changed, 16 insertions(+), 10 deletions(-)

diff --git a/dan/ocr/decoder.py b/dan/ocr/decoder.py
index 1b8b89d2..b9b256d1 100644
--- a/dan/ocr/decoder.py
+++ b/dan/ocr/decoder.py
@@ -513,20 +513,23 @@ class CTCLanguageDecoder:
         """
         Add CTC frames between each characters to avoid duplicate removal
         """
-        batch_size, _, n_tokens = batch_features.shape
+        high_prob = batch_features.max()
+        low_prob = batch_features.min()
+        batch_size, n_frames, n_tokens = batch_features.shape
 
-        # Create tensor with high probability CTC token
-        high_prob = 0.99
-        low_prob = 1 - high_prob
+        # Reset probabilities for the CTC token
+        batch_features[:, :, -1] = (
+            torch.ones((batch_size, n_frames), dtype=torch.float32) * low_prob
+        )
+
+        # Create a frame with high probability CTC token
         ctc_probs = (
-            torch.ones((batch_size, 1, n_tokens), dtype=torch.float32)
-            * low_prob
-            / (n_tokens - 1)
+            torch.ones((batch_size, 1, n_tokens), dtype=torch.float32) * low_prob
         )
         ctc_probs[:, :, self.blank_token_id] = high_prob
-        ctc_probs = ctc_probs.log()
+        ctc_probs = ctc_probs
 
-        # Insert CTC tensor between frames
+        # Insert the CTC frame between regular frames
         for fn in range(batch_frames.max() - 1):
             batch_features = torch.cat(
                 [
@@ -579,11 +582,14 @@ class CTCLanguageDecoder:
         # Reshape from (batch_size, n_tokens, n_frames) to (batch_size, n_frames, n_tokens)
         batch_features = batch_features.permute((0, 2, 1))
 
+        # Insert CTC frames to avoid getting rid of duplicates
+        # Make sure that the CTC token has low probs for other frames
+        batch_features, batch_frames = self.add_ctc_frames(batch_features, batch_frames)
+
         # Apply log softmax
         batch_features = torch.nn.functional.log_softmax(
             batch_features / self.temperature, dim=-1
         )
-        batch_features, batch_frames = self.add_ctc_frames(batch_features, batch_frames)
 
         # No GPU support for torchaudio's ctc_decoder
         batch_features = batch_features.to(self.device)
-- 
GitLab