From 80eb95e43ed2ee943fd43e412e8ce1b47b59d15b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sol=C3=A8ne=20Tarride?= <starride@teklia.com>
Date: Tue, 12 Sep 2023 11:26:01 +0200
Subject: [PATCH] Add CTC frame between each frames

---
 dan/ocr/decoder.py | 28 +++++++++++++++-------------
 1 file changed, 15 insertions(+), 13 deletions(-)

diff --git a/dan/ocr/decoder.py b/dan/ocr/decoder.py
index 8f80341a..6a158382 100644
--- a/dan/ocr/decoder.py
+++ b/dan/ocr/decoder.py
@@ -516,17 +516,19 @@ class CTCLanguageDecoder:
 
         # column with 1 probability on CTC token
         ctc_probs = (
-            torch.ones((batch_size, 1, n_tokens), dtype=torch.float32) * 0.1 / n_tokens
+            torch.ones((batch_size, 1, n_tokens), dtype=torch.float32)
+            * 0.01
+            / (n_tokens - 1)
         )
-        ctc_probs[:, :, self.ctc_id] = 0.9
+        ctc_probs[:, :, self.ctc_id] = 0.99
         ctc_probs = ctc_probs.log()
 
         for i in range(n_frames - 1):
             batch_features = torch.cat(
                 [
-                    batch_features[:, 2 * i + 1 :, :],
-                    ctc_probs,
                     batch_features[:, : 2 * i + 1, :],
+                    ctc_probs,
+                    batch_features[:, 2 * i + 1 :, :],
                 ],
                 dim=1,
             )
@@ -534,14 +536,15 @@ class CTCLanguageDecoder:
 
     def post_process(self, hypotheses):
         """
-        Post-process hypotheses to output JSON
+        Post-process hypotheses to output JSON. Exports only the best hypothesis for each image.
         """
         out = {}
-        # Export only the best hypothesis
+        # Replace <space> by an actual space and format string
         out["text"] = [
             "".join(hypothesis[0].words).replace(self.space_token, " ")
             for hypothesis in hypotheses
         ]
+        # Normalize confidence score
         out["confidence"] = [
             np.exp(hypothesis[0].score / hypothesis[0].timesteps[-1].item())
             for hypothesis in hypotheses
@@ -552,13 +555,12 @@ class CTCLanguageDecoder:
         """
         Decode a feature vector using n-gram language modelling.
         Args:
-            features (Any): feature vector of size (n_frame, batch_size, n_tokens).
-                Can be either a torch.tensor or a torch.nn.utils.rnn.PackedSequence
+            features (torch.tensor): feature vector of size (batch_size, n_tokens, n_frame).
+            batch_sizes (Union[List, torch.tensor]): actual length of predictions
         Returns:
-            out (Dict[str, List]): a dictionary containing the hypothesis (the list of decoded tokens).
-                There is no character-based probability.
+            out (Dict[List]): a dictionary containing the hypotheses.
         """
-        # Reshape from (n_frame, batch_size, n_tokens) to (batch_size, n_frame, n_tokens)
+        # Reshape from (batch_size, n_tokens, n_frames) to (batch_size, n_frames, n_tokens)
         batch_features = batch_features.permute((0, 2, 1))
 
         # Apply temperature scaling
@@ -566,8 +568,8 @@ class CTCLanguageDecoder:
 
         # Apply log softmax
         batch_features = torch.nn.functional.log_softmax(batch_features, dim=-1)
-        # batch_features = self.add_ctc_frames(batch_features)
-        # batch_sizes = batch_features.shape[0]
+        batch_features = self.add_ctc_frames(batch_features)
+        batch_sizes *= 2
 
         # No GPU support for torchaudio's ctc_decoder
         device = torch.device("cpu")
-- 
GitLab