From 80a27658ff2b3b138145d13deb6c30218ccf9436 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sol=C3=A8ne=20Tarride?= <starride@teklia.com>
Date: Tue, 12 Sep 2023 09:51:57 +0200
Subject: [PATCH] Implement CTCLanguageDecoder

---
 dan/ocr/decoder.py            | 120 ++++++++++++++++++++++++++++++++++
 dan/ocr/predict/__init__.py   |   6 ++
 dan/ocr/predict/prediction.py |  39 ++++++++++-
 dan/utils.py                  |   9 +++
 requirements.txt              |   1 +
 5 files changed, 174 insertions(+), 1 deletion(-)

diff --git a/dan/ocr/decoder.py b/dan/ocr/decoder.py
index ccf2b4d6..8f80341a 100644
--- a/dan/ocr/decoder.py
+++ b/dan/ocr/decoder.py
@@ -1,9 +1,13 @@
 # -*- coding: utf-8 -*-
 
+import numpy as np
 import torch
 from torch import relu, softmax
 from torch.nn import Conv1d, Dropout, Embedding, LayerNorm, Linear, Module, ModuleList
 from torch.nn.init import xavier_uniform_
+from torchaudio.models.decoder import ctc_decoder
+
+from dan.utils import read_txt
 
 
 class PositionalEncoding1D(Module):
@@ -459,3 +463,119 @@ class GlobalHTADecoder(Module):
                     ),
                 )
             )
+
+
+class CTCLanguageDecoder:
+    """
+    Initialize a CTC decoder with n-gram language modeling.
+    Args:
+        language_model_path (str): path to a KenLM or ARPA language model
+        lexicon_path (str): path to a lexicon file containing the possible words and corresponding spellings.
+            Each line consists of a word and its space separated spelling. If `None`, uses lexicon-free
+            decoding.
+        tokens_path (str): path to a file containing valid tokens. If using a file, the expected
+            format is for tokens mapping to the same index to be on the same line
+        language_model_weight (float): weight of the language model.
+        blank_token (str): token representing the blank/ctc symbol
+        unk_token (str): token representing unknown characters
+        sil_token (str): token representing the space character
+        temperature (float): value the logits are divided by before the log softmax
+    """
+
+    def __init__(
+        self,
+        language_model_path: str,
+        lexicon_path: str,
+        tokens_path: str,
+        language_model_weight: float = 1.0,
+        blank_token: str = "<ctc>",
+        unk_token: str = "<unk>",
+        sil_token: str = "<space>",
+        temperature: float = 1.0,
+    ):
+        self.decoder = ctc_decoder(
+            lm=language_model_path,
+            lexicon=lexicon_path,
+            tokens=tokens_path,
+            lm_weight=language_model_weight,
+            blank_token=blank_token,
+            unk_word=unk_token,
+            sil_token=sil_token,
+            nbest=1,
+        )
+        self.temperature = temperature
+
+        # The position of a line in the tokens file is used as the index of its token
+        self.tokens_to_idx = read_txt(tokens_path).split("\n")
+        self.ctc_id = self.tokens_to_idx.index(blank_token)
+        self.space_token = sil_token
+
+    def add_ctc_frames(self, batch_features):
+        """
+        Insert a CTC frame between consecutive frames to avoid duplicate removal.
+        Args:
+            batch_features (torch.Tensor): log-probabilities of size (batch_size, n_frame, n_tokens).
+        Returns:
+            torch.Tensor: frames interleaved with CTC frames, of size (batch_size, 2 * n_frame - 1, n_tokens).
+        """
+        batch_size, n_frames, n_tokens = batch_features.shape
+        if n_frames < 2:
+            return batch_features
+
+        # Frame with 0.9 probability on the CTC token, built on the same device as the features
+        ctc_probs = torch.full(
+            (n_tokens,),
+            0.1 / n_tokens,
+            dtype=batch_features.dtype,
+            device=batch_features.device,
+        )
+        ctc_probs[self.ctc_id] = 0.9
+
+        # Even positions keep the original frames in order, odd positions hold the CTC frame
+        interleaved = ctc_probs.log().repeat(batch_size, 2 * n_frames - 1, 1)
+        interleaved[:, ::2, :] = batch_features
+        return interleaved
+
+    def post_process(self, hypotheses):
+        """
+        Post-process hypotheses to output JSON
+        Args:
+            hypotheses (List): decoder output; the first hypothesis of each batch item is used.
+        Returns:
+            out (Dict[str, List]): the "text" and the "confidence" of each batch item.
+        """
+        # Export only the best hypothesis
+        best_hypotheses = [hypothesis[0] for hypothesis in hypotheses]
+        return {
+            "text": [
+                "".join(best.words).replace(self.space_token, " ")
+                for best in best_hypotheses
+            ],
+            "confidence": [
+                np.exp(best.score / best.timesteps[-1].item())
+                for best in best_hypotheses
+            ],
+        }
+
+    def __call__(self, batch_features, batch_sizes):
+        """
+        Decode a feature vector using n-gram language modelling.
+        Args:
+            batch_features (torch.Tensor): logits of size (batch_size, n_tokens, n_frame).
+            batch_sizes (Union[List[int], torch.Tensor]): number of valid frames of each batch item.
+        Returns:
+            out (Dict[str, List]): a dictionary containing the decoded text ("text") and its score
+                ("confidence") for each batch item. There is no character-based probability.
+        """
+        # Reshape from (batch_size, n_tokens, n_frame) to (batch_size, n_frame, n_tokens)
+        batch_features = batch_features.permute((0, 2, 1))
+
+        # Apply temperature scaling, then log softmax over the tokens
+        batch_features = batch_features / self.temperature
+        batch_features = torch.nn.functional.log_softmax(batch_features, dim=-1)
+
+        # No GPU support for torchaudio's ctc_decoder: features and lengths are moved to the CPU
+        device = torch.device("cpu")
+        batch_sizes = torch.as_tensor(batch_sizes).to(device)
+        hypotheses = self.decoder(batch_features.to(device), batch_sizes)
+        return self.post_process(hypotheses)
diff --git a/dan/ocr/predict/__init__.py b/dan/ocr/predict/__init__.py
index 952bcd6a..518b4eb0 100644
--- a/dan/ocr/predict/__init__.py
+++ b/dan/ocr/predict/__init__.py
@@ -167,4 +167,10 @@ def add_predict_parser(subcommands) -> None:
         type=str,
         required=False,
     )
+    parser.add_argument(
+        "--use-language-model",
+        help="Whether to use an explicit language model to rescore text hypothesis.",
+        action="store_true",
+        required=False,
+    )
     parser.set_defaults(func=run)
diff --git a/dan/ocr/predict/prediction.py b/dan/ocr/predict/prediction.py
index 130692f1..e7da4c88 100644
--- a/dan/ocr/predict/prediction.py
+++ b/dan/ocr/predict/prediction.py
@@ -12,7 +12,7 @@ import numpy as np
 import torch
 import yaml
 
-from dan.ocr.decoder import GlobalHTADecoder
+from dan.ocr.decoder import CTCLanguageDecoder, GlobalHTADecoder
 from dan.ocr.encoder import FCN_Encoder
 from dan.ocr.predict.attention import (
     Level,
@@ -75,6 +75,17 @@ class DAN:
 
         decoder = GlobalHTADecoder(parameters["decoder"]).to(self.device)
         decoder.load_state_dict(checkpoint["decoder_state_dict"], strict=True)
+
+        self.lm_decoder = CTCLanguageDecoder(
+            language_model_path=parameters["lm_decoder"]["language_model_path"],
+            lexicon_path=parameters["lm_decoder"]["lexicon_path"],
+            tokens_path=parameters["lm_decoder"]["tokens_path"],
+            language_model_weight=parameters["lm_decoder"]["language_model_weight"],
+            blank_token=parameters["lm_decoder"]["blank_token"],
+            unk_token=parameters["lm_decoder"]["unk_token"],
+            sil_token=parameters["lm_decoder"]["sil_token"],
+        )
+
         logger.debug(f"Loaded model {model_path}")
 
         if mode == "train":
@@ -125,6 +136,7 @@ class DAN:
         threshold_method: str = "otsu",
         threshold_value: int = 0,
         max_object_height: int = 50,
+        use_language_model: bool = False,
     ) -> dict:
         """
         Run prediction on an input image.
@@ -162,6 +174,12 @@ class DAN:
                 (batch_size,), dtype=torch.int, device=self.device
             )
 
+            tot_pred = torch.zeros(
+                (batch_size, len(self.charset) + 1, self.max_chars),
+                dtype=torch.float,
+                device=self.device,
+            )
+
             whole_output = list()
             confidence_scores = list()
             attention_maps = list()
@@ -192,6 +210,9 @@ class DAN:
                     num_pred=1,
                 )
 
+                # output total logit prediction
+                tot_pred[:, :, i : i + 1] = pred
+
                 pred = pred / self.temperature
                 whole_output.append(output)
                 attention_maps.append(weights)
@@ -242,6 +263,8 @@ class DAN:
         out = {}
 
         out["text"] = predicted_text
+        if use_language_model:
+            out["language_model"] = self.lm_decoder(tot_pred, predicted_tokens_len)
         if confidences:
             out["confidences"] = confidence_scores
         if attentions:
@@ -296,6 +319,7 @@ def process_batch(
     max_object_height: int,
     tokens: Dict[str, EntityType],
     start_token: str,
+    use_language_model: bool,
 ) -> None:
     input_images, visu_images, input_sizes = [], [], []
     logger.info("Loading images...")
@@ -330,6 +354,7 @@ def process_batch(
         threshold_value=threshold_value,
         max_object_height=max_object_height,
         start_token=start_token,
+        use_language_model=use_language_model,
     )
 
     logger.info("Prediction parsing...")
@@ -337,6 +362,15 @@ def process_batch(
         predicted_text = prediction["text"][idx]
         result = {"text": predicted_text}
 
+        # Return LM results
+        if use_language_model:
+            # Log at debug level instead of printing the whole batch prediction for every image
+            logger.debug("Language model prediction: %s", prediction["language_model"])
+            result["language_model"] = {
+                "text": prediction["language_model"]["text"][idx],
+                "confidence": prediction["language_model"]["confidence"][idx],
+            }
+
         # Return extracted objects (coordinates, text, confidence)
         if predict_objects:
             result["objects"] = prediction["objects"][idx]
@@ -435,6 +469,7 @@ def run(
     batch_size: int,
     tokens: Dict[str, EntityType],
     start_token: str,
+    use_language_model: bool,
 ) -> None:
     """
     Predict a single image save the output
@@ -458,6 +493,7 @@ def run(
     :param batch_size: Size of the batches for prediction.
     :param tokens: NER tokens used.
     :param start_token: Use a specific starting token at the beginning of the prediction. Useful when making predictions on different single pages.
+    :param use_language_model: Whether to use an explicit language model to rescore text hypothesis.
     """
     # Create output directory if necessary
     if not output.exists():
@@ -489,4 +525,5 @@ def run(
             max_object_height,
             tokens,
             start_token,
+            use_language_model,
         )
diff --git a/dan/utils.py b/dan/utils.py
index d825adf7..f813723e 100644
--- a/dan/utils.py
+++ b/dan/utils.py
@@ -137,3 +137,12 @@ def read_json(json_path: str) -> Dict:
         return json.loads(filename.read_text())
     except json.JSONDecodeError as e:
         raise ArgumentTypeError(e)
+
+
+def read_txt(txt_path: str) -> str:
+    """
+    Read a TXT file and return its whole content, decoded as UTF-8 whatever the locale is.
+    """
+    filename = Path(txt_path)
+    assert filename.exists(), f"{txt_path} does not resolve."
+    return filename.read_text(encoding="utf-8")
diff --git a/requirements.txt b/requirements.txt
index 99e94e92..d3715a7e 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -13,5 +13,6 @@ teklia-line-image-extractor==0.2.8rc4
 tenacity==8.2.3
 tensorboard==2.12.2
 torch==2.0.0
+torchaudio==2.0.1
 torchvision==0.15.1
 tqdm==4.65.0
-- 
GitLab