From 2538568cbeb9e0ae692a81903b50d6d25407aa09 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sol=C3=A8ne=20Tarride?= <starride@teklia.com>
Date: Tue, 12 Sep 2023 09:51:57 +0200
Subject: [PATCH] Implement CTCLanguageDecoder
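
Add a CTCLanguageDecoder that rescores the character probabilities
predicted by DAN with an n-gram language model, through torchaudio's
ctc_decoder.

A minimal usage sketch (file paths and tensor sizes are placeholders,
not files shipped with this patch):

    import torch

    from dan.ocr.decoder import CTCLanguageDecoder

    decoder = CTCLanguageDecoder(
        language_model_path="lm.arpa",
        lexicon_path="lexicon.txt",
        tokens_path="tokens.txt",
        language_model_weight=1.0,
    )
    # Raw character logits of shape (batch_size, n_tokens, n_frames);
    # n_tokens must match the number of lines in tokens.txt
    features = torch.randn((1, 10, 50))
    out = decoder(features, batch_sizes=[50])
    print(out["text"], out["confidence"])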

---
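Notes: the model loading code in dan/ocr/predict/inference.py now
expects an "lm_decoder" section in the model parameters. A sketch of
the expected keys, with illustrative values (the actual paths depend
on your setup):

    parameters["lm_decoder"] = {
        "language_model_path": "path/to/lm.arpa",
        "lexicon_path": "path/to/lexicon.txt",
        "tokens_path": "path/to/tokens.txt",
        "language_model_weight": 1.0,
        "blank_token": "<ctc>",
        "unk_token": "<unk>",
        "sil_token": "<space>",
    }
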
 dan/ocr/decoder.py           | 163 +++++++++++++----------------------
 dan/ocr/predict/__init__.py  |   2 +-
 dan/ocr/predict/inference.py |  15 +++-
 dan/utils.py                 |   4 +-
 4 files changed, 75 insertions(+), 109 deletions(-)

diff --git a/dan/ocr/decoder.py b/dan/ocr/decoder.py
index b62c0f6c..8f80341a 100644
--- a/dan/ocr/decoder.py
+++ b/dan/ocr/decoder.py
@@ -1,15 +1,13 @@
 # -*- coding: utf-8 -*-
 
-from typing import Dict, List, Union
-
 import numpy as np
 import torch
 from torch import relu, softmax
 from torch.nn import Conv1d, Dropout, Embedding, LayerNorm, Linear, Module, ModuleList
 from torch.nn.init import xavier_uniform_
-from torchaudio.models.decoder import CTCHypothesis, ctc_decoder
+from torchaudio.models.decoder import ctc_decoder
 
-from dan.utils import LMTokenMapping, read_txt
+from dan.utils import read_txt
 
 
 class PositionalEncoding1D(Module):
@@ -470,13 +468,17 @@ class GlobalHTADecoder(Module):
 class CTCLanguageDecoder:
     """
     Initialize a CTC decoder with n-gram language modeling.
-    :param language_model_path: Path to a KenLM or ARPA language model.
-    :param lexicon_path: Path to a lexicon file containing the possible words and corresponding spellings.
-            Each line consists of a word and its space separated spelling. If `None`, uses lexicon-free decoding.
-    :param tokens_path: Path to a file containing valid tokens. If using a file, the expected
-            format is for tokens mapping to the same index to be on the same line.
-    :param language_model_weight: Weight of the language model.
-    :param temperature: Temperature for model calibreation.
+    Args:
+        language_model_path (str): path to a KenLM or ARPA language model.
+        lexicon_path (str): path to a lexicon file listing the possible words and their
+            spellings. Each line consists of a word followed by its space-separated
+            spelling. If `None`, lexicon-free decoding is used.
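+            For example, a character-level entry might look like
+            `weather w e a t h e r` (hypothetical content).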
+        tokens_path (str): path to a file containing the valid tokens. Tokens mapping
+            to the same index must be on the same line.
+        language_model_weight (float): weight of the language model during decoding.
+        blank_token (str): token representing the CTC blank symbol.
+        unk_token (str): token representing unknown characters.
+        sil_token (str): token representing the space character.
+        temperature (float): temperature used to scale the logits before decoding.
     """
 
     def __init__(
@@ -485,138 +487,95 @@ class CTCLanguageDecoder:
         lexicon_path: str,
         tokens_path: str,
         language_model_weight: float = 1.0,
+        blank_token: str = "<ctc>",
+        unk_token: str = "<unk>",
+        sil_token: str = "<space>",
         temperature: float = 1.0,
     ):
-        self.mapping = LMTokenMapping()
-        self.language_model_weight = language_model_weight
-        self.temperature = temperature
-        self.tokens_to_index = {
-            token: i for i, token in enumerate(read_txt(tokens_path).split("\n"))
-        }
-        self.index_to_token = {i: token for token, i in self.tokens_to_index.items()}
-        self.blank_token_id = self.tokens_to_index[self.mapping.ctc.encoded]
-
-        # Torchaudio's decoder
-        # https://pytorch.org/audio/master/generated/torchaudio.models.decoder.ctc_decoder.html
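+        # Torchaudio's CTC decoder:
+        # https://pytorch.org/audio/master/generated/torchaudio.models.decoder.ctc_decoder.html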
         self.decoder = ctc_decoder(
             lm=language_model_path,
             lexicon=lexicon_path,
             tokens=tokens_path,
-            lm_weight=self.language_model_weight,
-            blank_token=self.mapping.ctc.encoded,
-            sil_token=self.mapping.space.encoded,
-            unk_word="⁇",
+            lm_weight=language_model_weight,
+            blank_token=blank_token,
+            unk_word=unk_token,
+            sil_token=sil_token,
             nbest=1,
         )
-        # No GPU support
-        self.device = torch.device("cpu")
+        self.temperature = temperature
 
-    def add_ctc_frames(
-        self, batch_features: torch.FloatTensor, batch_frames: torch.LongTensor
-    ) -> tuple[torch.FloatTensor, torch.LongTensor]:
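+        # The tokens file lists one token per line; the line number gives the token index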
+        self.tokens = read_txt(tokens_path).split("\n")
+        self.ctc_id = self.tokens.index(blank_token)
+        self.space_token = sil_token
+
+    def add_ctc_frames(self, batch_features: torch.FloatTensor) -> torch.FloatTensor:
         """
-        Add CTC frames between each characters to avoid duplicate removal.
+        Add a CTC frame between consecutive frames to prevent duplicate removal.
         """
-        high_prob = batch_features.max()
-        low_prob = batch_features.min()
         batch_size, n_frames, n_tokens = batch_features.shape
-        # Reset probabilities for the CTC token
-        batch_features[:, :, -1] = (
-            torch.ones(
-                (batch_size, n_frames),
-                dtype=torch.float32,
-                device=batch_features.device,
-            )
-            * low_prob
-        )
 
-        # Create a frame with high probability CTC token
+        # Frame that concentrates most of the probability mass (0.9) on the CTC token
         ctc_probs = (
-            torch.ones(
-                (batch_size, 1, n_tokens),
-                dtype=torch.float32,
-                device=batch_features.device,
-            )
-            * low_prob
+            torch.ones((batch_size, 1, n_tokens), dtype=torch.float32) * 0.1 / n_tokens
         )
-        ctc_probs[:, :, self.blank_token_id] = high_prob
-        ctc_probs = ctc_probs
+        ctc_probs[:, :, self.ctc_id] = 0.9
+        ctc_probs = ctc_probs.log()
 
-        # Insert the CTC frame between regular frames
-        for fn in range(batch_frames.max() - 1):
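+        # Interleave a CTC frame after each original frame so that repeated
+        # characters are not collapsed: [f0, f1, f2] -> [f0, ctc, f1, ctc, f2]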
+        for i in range(n_frames - 1):
             batch_features = torch.cat(
                 [
-                    batch_features[:, : 2 * fn + 1, :],
+                    batch_features[:, : 2 * i + 1, :],
                     ctc_probs,
-                    batch_features[:, 2 * fn + 1 :, :],
+                    batch_features[:, 2 * i + 1 :, :],
                 ],
                 dim=1,
             )
+        return batch_features
 
-        # Update the number of frames
-        batch_frames = 2 * batch_frames - 1
-        return batch_features, batch_frames
-
-    def post_process(
-        self, hypotheses: List[CTCHypothesis], batch_sizes: torch.LongTensor
-    ) -> Dict[str, List[Union[str, float]]]:
+    def post_process(self, hypotheses):
         """
-        Post-process hypotheses to output JSON. Exports only the best hypothesis for each image.
-        :param hypotheses: List of hypotheses returned by the decoder.
-        :param batch_sizes: Prediction length of size batch_size.
-        :return: A dictionary containing the hypotheses and their confidences.
+        Post-process hypotheses to output JSON. Exports only the best hypothesis for each image.
         """
         out = {}
-        # Replace <space> by an actual space and format string
+        # Export only the best hypothesis
         out["text"] = [
-            "".join(
-                [
-                    self.mapping.display[self.index_to_token[token]]
-                    if self.index_to_token[token] in self.mapping.display
-                    else self.index_to_token[token]
-                    for token in hypothesis[0].tokens.tolist()
-                ]
-            ).strip()
+            "".join(hypothesis[0].words).replace(self.space_token, " ")
             for hypothesis in hypotheses
         ]
-        # Normalize confidence score
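+        # Normalize the score by the last timestep and convert it back to a probability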
         out["confidence"] = [
-            np.around(
-                np.exp(
-                    hypothesis[0].score
-                    / ((self.language_model_weight + 1) * length.item())
-                ),
-                2,
-            )
-            for hypothesis, length in zip(hypotheses, batch_sizes)
+            np.exp(hypothesis[0].score / hypothesis[0].timesteps[-1].item())
+            for hypothesis in hypotheses
         ]
         return out
 
-    def __call__(
-        self, batch_features: torch.FloatTensor, batch_frames: torch.LongTensor
-    ) -> Dict[str, List[Union[str, float]]]:
+    def __call__(self, batch_features, batch_sizes):
         """
         Decode a feature vector using n-gram language modelling.
-        :param batch_features: Feature vector of size (batch_size, n_tokens, n_frames).
-        :param batch_frames: Prediction length of size batch_size.
-        :return: A dictionary containing the hypotheses and their confidences.
+        Args:
+            batch_features (torch.FloatTensor): feature vector of size
+                (batch_size, n_tokens, n_frames).
+            batch_sizes (List[int] or torch.LongTensor): prediction length for each sample.
+        Returns:
+            out (Dict[str, List]): a dictionary containing the decoded text and a
+                confidence score for each sample. There is no character-level probability.
         """
-        # Reshape from (batch_size, n_tokens, n_frames) to (batch_size, n_frames, n_tokens)
+        # Reshape from (batch_size, n_tokens, n_frames) to (batch_size, n_frames, n_tokens)
         batch_features = batch_features.permute((0, 2, 1))
 
-        # Insert CTC frames to avoid getting rid of duplicates
-        # Make sure that the CTC token has low probs for other frames
-        batch_features, batch_frames = self.add_ctc_frames(batch_features, batch_frames)
+        # Apply temperature scaling
+        batch_features = batch_features / self.temperature
 
         # Apply log softmax
-        batch_features = torch.nn.functional.log_softmax(
-            batch_features / self.temperature, dim=-1
-        )
+        batch_features = torch.nn.functional.log_softmax(batch_features, dim=-1)
+        # batch_features = self.add_ctc_frames(batch_features)
+        # batch_sizes = 2 * batch_sizes - 1
 
         # No GPU support for torchaudio's ctc_decoder
-        batch_features = batch_features.to(self.device)
-        batch_frames = batch_frames.to(self.device)
+        device = torch.device("cpu")
+        batch_features = batch_features.to(device)
+        if isinstance(batch_sizes, list):
+            batch_sizes = torch.tensor(batch_sizes)
+        batch_sizes = batch_sizes.to(device)
 
         # Decode
-        hypotheses = self.decoder(batch_features, batch_frames)
-        return self.post_process(hypotheses, batch_frames)
+        hypotheses = self.decoder(batch_features, batch_sizes)
+        return self.post_process(hypotheses)
diff --git a/dan/ocr/predict/__init__.py b/dan/ocr/predict/__init__.py
index e5f7b2bc..da477c1f 100644
--- a/dan/ocr/predict/__init__.py
+++ b/dan/ocr/predict/__init__.py
@@ -169,7 +169,7 @@ def add_predict_parser(subcommands) -> None:
     )
     parser.add_argument(
         "--use-language-model",
-        help="Whether to use an explicit language model to rescore text hypotheses.",
+        help="Whether to use an explicit language model to rescore text hypothesis.",
         action="store_true",
         required=False,
     )
diff --git a/dan/ocr/predict/inference.py b/dan/ocr/predict/inference.py
index e33d0ba2..a2096015 100644
--- a/dan/ocr/predict/inference.py
+++ b/dan/ocr/predict/inference.py
@@ -77,6 +77,16 @@ class DAN:
         decoder = GlobalHTADecoder(parameters["decoder"]).to(self.device)
         decoder.load_state_dict(checkpoint["decoder_state_dict"], strict=True)
 
+        self.lm_decoder = CTCLanguageDecoder(
+            language_model_path=parameters["lm_decoder"]["language_model_path"],
+            lexicon_path=parameters["lm_decoder"]["lexicon_path"],
+            tokens_path=parameters["lm_decoder"]["tokens_path"],
+            language_model_weight=parameters["lm_decoder"]["language_model_weight"],
+            blank_token=parameters["lm_decoder"]["blank_token"],
+            unk_token=parameters["lm_decoder"]["unk_token"],
+            sil_token=parameters["lm_decoder"]["sil_token"],
+        )
+
         logger.debug(f"Loaded model {model_path}")
 
         if mode == "train":
@@ -179,7 +189,6 @@ class DAN:
                 (batch_size,), dtype=torch.int, device=self.device
             )
 
-            # end token index will be used for ctc
             tot_pred = torch.zeros(
                 (batch_size, len(self.charset) + 1, self.max_chars),
                 dtype=torch.float,
@@ -270,7 +279,7 @@ class DAN:
 
         out["text"] = predicted_text
         if use_language_model:
-            out["language_model"] = self.lm_decoder(tot_pred, prediction_len)
+            out["language_model"] = self.lm_decoder(tot_pred, predicted_tokens_len)
         if confidences:
             out["confidences"] = confidence_scores
         if attentions:
@@ -466,7 +475,7 @@ def run(
     :param batch_size: Size of the batches for prediction.
     :param tokens: NER tokens used.
     :param start_token: Use a specific starting token at the beginning of the prediction. Useful when making predictions on different single pages.
-    :param use_language_model: Whether to use an explicit language model to rescore text hypotheses.
+    :param use_language_model: Whether to use an explicit language model to rescore text hypotheses.
     """
     # Create output directory if necessary
     if not output.exists():
diff --git a/dan/utils.py b/dan/utils.py
index 176c56c9..69e7d82a 100644
--- a/dan/utils.py
+++ b/dan/utils.py
@@ -163,9 +163,7 @@ def read_json(json_path: str) -> Dict:
 
 def read_txt(txt_path: str) -> str:
     """
-    Read TXT file.
-    :param txt_path: Path of the text file to read.
-    :return: The content of the read file.
+    Read a text file and return its content.
     """
     filename = Path(txt_path)
     assert filename.exists(), f"{txt_path} does not resolve."
-- 
GitLab