From 1cccd99b715f2d954f68ea6fcb82b3740b50ed33 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sol=C3=A8ne=20Tarride?= <starride@teklia.com>
Date: Tue, 19 Sep 2023 12:41:40 +0200
Subject: [PATCH] Fix shape of tot_prob

---
 dan/ocr/decoder.py           | 80 +++++++++++++++++++-----------------
 dan/ocr/predict/inference.py | 14 +++++--
 2 files changed, 53 insertions(+), 41 deletions(-)

diff --git a/dan/ocr/decoder.py b/dan/ocr/decoder.py
index 59909d14..bd1044d7 100644
--- a/dan/ocr/decoder.py
+++ b/dan/ocr/decoder.py
@@ -7,7 +7,7 @@ from torch.nn import Conv1d, Dropout, Embedding, LayerNorm, Linear, Module, Modu
 from torch.nn.init import xavier_uniform_
 from torchaudio.models.decoder import ctc_decoder
 
-from dan.utils import read_txt
+from dan.utils import LM_MAPPING, read_txt
 
 
 class PositionalEncoding1D(Module):
@@ -487,56 +487,65 @@ class CTCLanguageDecoder:
         lexicon_path: str,
         tokens_path: str,
         language_model_weight: float = 1.0,
-        blank_token: str = "<ctc>",
-        unk_token: str = "<unk>",
-        sil_token: str = "<space>",
         temperature: float = 1.0,
     ):
+        self.space_token = LM_MAPPING[" "]
+        self.unknown_token = LM_MAPPING["<unk>"]
+        self.blank_token = LM_MAPPING["<ctc>"]
+        self.language_model_weight = language_model_weight
+        self.temperature = temperature
+        self.tokens_to_index = {
+            token: i for i, token in enumerate(read_txt(tokens_path).split("\n"))
+        }
+        self.blank_token_id = self.tokens_to_index[self.blank_token]
         self.decoder = ctc_decoder(
             lm=language_model_path,
             lexicon=lexicon_path,
             tokens=tokens_path,
-            lm_weight=language_model_weight,
-            blank_token=blank_token,
-            unk_word=unk_token,
-            sil_token=sil_token,
+            lm_weight=self.language_model_weight,
+            blank_token=self.blank_token,
+            unk_word=self.unknown_token,
+            sil_token=self.space_token,
             nbest=1,
         )
-        self.temperature = temperature
-        self.space_token = sil_token
-        self.tokens_to_idx = read_txt(tokens_path).split("\n")
-        self.blank_id = self.tokens_to_idx.index(blank_token)
+        # No GPU support
+        self.device = torch.device("cpu")
 
-    def add_ctc_frames(self, batch_features):
+    def add_ctc_frames(self, batch_features, batch_frames):
         """
         Add CTC frames between each characters to avoid duplicate removal
         """
-        batch_size, n_frames, n_tokens = batch_features.shape
+        batch_size, _, n_tokens = batch_features.shape
+        torch.clone(batch_features)
+        # visualize_debug(batch_features.exp()[0, :batch_frames[0], :].numpy(), "probs.jpg", False)
 
         # Create tensor with high probability CTC token
         high_prob = 0.99
-        low_prob = 0.01
+        low_prob = 1 - high_prob
         ctc_probs = (
             torch.ones((batch_size, 1, n_tokens), dtype=torch.float32)
             * low_prob
             / (n_tokens - 1)
         )
-        ctc_probs[:, :, self.blank_id] = high_prob
+        ctc_probs[:, :, self.blank_token_id] = high_prob
         ctc_probs = ctc_probs.log()
 
         # Insert CTC tensor between frames
-        for i in range(n_frames):
+        for fn in range(batch_frames[0] - 1):
             batch_features = torch.cat(
                 [
-                    batch_features[:, : 2 * i + 1, :],
+                    batch_features[:, : 2 * fn + 1, :],
                     ctc_probs,
-                    batch_features[:, 2 * i + 1 :, :],
+                    batch_features[:, 2 * fn + 1 :, :],
                 ],
                 dim=1,
             )
-        return batch_features
 
-    def post_process(self, hypotheses):
+        # Update the number of frames
+        batch_frames = 2 * batch_frames - 1
+        return batch_features, batch_frames
+
+    def post_process(self, hypotheses, batch_sizes):
         """
         Post-process hypotheses to output JSON. Exports only the best hypothesis for each image.
         """
@@ -548,12 +557,14 @@ class CTCLanguageDecoder:
         ]
         # Normalize confidence score
         out["confidence"] = [
-            np.exp(hypothesis[0].score / hypothesis[0].timesteps[-1].item())
-            for hypothesis in hypotheses
+            np.exp(
+                hypothesis[0].score / ((self.language_model_weight + 1) * length.item())
+            )
+            for hypothesis, length in zip(hypotheses, batch_sizes)
         ]
         return out
 
-    def __call__(self, batch_features, batch_sizes):
+    def __call__(self, batch_features, batch_frames):
         """
         Decode a feature vector using n-gram language modelling.
         Args:
@@ -565,21 +576,16 @@ class CTCLanguageDecoder:
         # Reshape from (batch_size, n_tokens, n_frames) to (batch_size, n_frames, n_tokens)
         batch_features = batch_features.permute((0, 2, 1))
 
-        # Apply temperature scaling
-        batch_features = batch_features / self.temperature
-
         # Apply log softmax
-        batch_features = torch.nn.functional.log_softmax(batch_features, dim=-1)
-        batch_features = self.add_ctc_frames(batch_features)
-        batch_sizes *= 2
+        batch_features = torch.nn.functional.log_softmax(
+            batch_features / self.temperature, dim=-1
+        )
+        batch_features, batch_frames = self.add_ctc_frames(batch_features, batch_frames)
 
         # No GPU support for torchaudio's ctc_decoder
-        device = torch.device("cpu")
-        batch_features = batch_features.to(device)
-        if isinstance(batch_sizes, list):
-            batch_sizes = torch.tensor(batch_sizes)
-            batch_sizes.to(device)
+        batch_features = batch_features.to(self.device)
+        batch_frames = batch_frames.to(self.device)
 
         # Decode
-        hypotheses = self.decoder(batch_features, batch_sizes)
-        return self.post_process(hypotheses)
+        hypotheses = self.decoder(batch_features, batch_frames)
+        return self.post_process(hypotheses, batch_frames)
diff --git a/dan/ocr/predict/inference.py b/dan/ocr/predict/inference.py
index 4153633b..b04995d0 100644
--- a/dan/ocr/predict/inference.py
+++ b/dan/ocr/predict/inference.py
@@ -90,6 +90,7 @@ class DAN:
 
         self.encoder = encoder
         self.decoder = decoder
+        self.lm_decoder = None
 
         if use_language_model:
             self.lm_decoder = CTCLanguageDecoder(
@@ -97,9 +98,6 @@ class DAN:
                 lexicon_path=parameters["lm_decoder"]["lexicon_path"],
                 tokens_path=parameters["lm_decoder"]["tokens_path"],
                 language_model_weight=parameters["lm_decoder"]["language_model_weight"],
-                blank_token=parameters["lm_decoder"]["blank_token"],
-                unk_token=parameters["lm_decoder"]["unk_token"],
-                sil_token=parameters["lm_decoder"]["sil_token"],
             )
 
         self.mean, self.std = (
@@ -178,6 +176,7 @@ class DAN:
                 (batch_size,), dtype=torch.int, device=self.device
             )
 
+            # end token index will be used for ctc
             tot_pred = torch.zeros(
                 (batch_size, len(self.charset) + 1, self.max_chars),
                 dtype=torch.float,
@@ -268,7 +267,7 @@ class DAN:
 
         out["text"] = predicted_text
         if use_language_model:
-            out["language_model"] = self.lm_decoder(tot_pred, predicted_tokens_len)
+            out["language_model"] = self.lm_decoder(tot_pred, prediction_len)
         if confidences:
             out["confidences"] = confidence_scores
         if attentions:
@@ -474,7 +473,14 @@ def run(
     cuda_device = f":{gpu_device}" if gpu_device is not None else ""
     device = f"cuda{cuda_device}" if torch.cuda.is_available() else "cpu"
     dan_model = DAN(device, temperature)
+<<<<<<< HEAD
     dan_model.load(model, parameters, charset, mode="eval")
+=======
+    dan_model.load(
+        model, parameters, charset, mode="eval", use_language_model=use_language_model
+    )
+    batch_size = 1 if use_language_model else batch_size
+>>>>>>> e7c611f (Fix shape of tot_prob)
 
     images = image_dir.rglob(f"*{image_extension}") if not image else [image]
     for image_batch in list_to_batches(images, n=batch_size):
-- 
GitLab