From 80a27658ff2b3b138145d13deb6c30218ccf9436 Mon Sep 17 00:00:00 2001
From: Solène Tarride <starride@teklia.com>
Date: Tue, 12 Sep 2023 09:51:57 +0200
Subject: [PATCH] Implement CTCLanguageDecoder

---
 dan/ocr/decoder.py            | 120 ++++++++++++++++++++++++++++++++++
 dan/ocr/predict/__init__.py   |   6 ++
 dan/ocr/predict/prediction.py |  38 ++++++++++-
 dan/utils.py                  |   9 +++
 requirements.txt              |   1 +
 5 files changed, 173 insertions(+), 1 deletion(-)

diff --git a/dan/ocr/decoder.py b/dan/ocr/decoder.py
index ccf2b4d6..8f80341a 100644
--- a/dan/ocr/decoder.py
+++ b/dan/ocr/decoder.py
@@ -1,9 +1,13 @@
 # -*- coding: utf-8 -*-
 
+import numpy as np
 import torch
 from torch import relu, softmax
 from torch.nn import Conv1d, Dropout, Embedding, LayerNorm, Linear, Module, ModuleList
 from torch.nn.init import xavier_uniform_
+from torchaudio.models.decoder import ctc_decoder
+
+from dan.utils import read_txt
 
 
 class PositionalEncoding1D(Module):
@@ -459,3 +463,119 @@ class GlobalHTADecoder(Module):
                 ),
             )
         )
+
+
+class CTCLanguageDecoder:
+    """
+    CTC decoder with n-gram language modeling, based on torchaudio's `ctc_decoder`.
+    Args:
+        language_model_path (str): path to a KenLM or ARPA language model.
+        lexicon_path (str): path to a lexicon file mapping each word to its space-separated
+            spelling. If `None`, lexicon-free decoding is used.
+        tokens_path (str): path to a file containing valid tokens. Tokens mapping to the
+            same index are expected to be on the same line.
+        language_model_weight (float): weight of the language model.
+        blank_token (str): token representing the blank/CTC symbol.
+        unk_token (str): token representing unknown characters.
+        sil_token (str): token representing the space character.
+        temperature (float): temperature used to scale the logits before decoding.
+    """
+
+    def __init__(
+        self,
+        language_model_path: str,
+        lexicon_path: str,
+        tokens_path: str,
+        language_model_weight: float = 1.0,
+        blank_token: str = "<ctc>",
+        unk_token: str = "<unk>",
+        sil_token: str = "<space>",
+        temperature: float = 1.0,
+    ):
+        self.decoder = ctc_decoder(
+            lm=language_model_path,
+            lexicon=lexicon_path,
+            tokens=tokens_path,
+            lm_weight=language_model_weight,
+            blank_token=blank_token,
+            unk_word=unk_token,
+            sil_token=sil_token,
+            nbest=1,
+        )
+        self.temperature = temperature
+
+        self.tokens_to_idx = read_txt(tokens_path).split("\n")
+        self.ctc_id = self.tokens_to_idx.index(blank_token)
+        self.space_token = sil_token
+
+    def add_ctc_frames(self, batch_features):
+        """
+        Insert CTC frames between characters so that repeated characters are not collapsed
+        """
+        batch_size, n_frames, n_tokens = batch_features.shape
+
+        # Column with most of the probability mass on the CTC token
+        ctc_probs = (
+            torch.ones((batch_size, 1, n_tokens), dtype=torch.float32) * 0.1 / n_tokens
+        )
+        ctc_probs[:, :, self.ctc_id] = 0.9
+        ctc_probs = ctc_probs.log()
+
+        for i in range(n_frames - 1):
+            batch_features = torch.cat(
+                [
+                    batch_features[:, : 2 * i + 1, :],
+                    ctc_probs,
+                    batch_features[:, 2 * i + 1 :, :],
+                ],
+                dim=1,
+            )
+        return batch_features
+
+    def post_process(self, hypotheses):
+        """
+        Post-process hypotheses to output JSON
+        """
+        out = {}
+        # Export only the best hypothesis
+        out["text"] = [
+            "".join(hypothesis[0].words).replace(self.space_token, " ")
+            for hypothesis in hypotheses
+        ]
+        out["confidence"] = [
+            np.exp(hypothesis[0].score / hypothesis[0].timesteps[-1].item())
+            for hypothesis in hypotheses
+        ]
+        return out
+
+    def __call__(self, batch_features, batch_sizes):
+        """
+        Decode a batch of feature vectors using n-gram language modelling.
+        Args:
+            batch_features (torch.Tensor): feature tensor of size (batch_size, n_tokens, n_frames).
+            batch_sizes (List[int] | torch.Tensor): length of the decoded sequence for each image.
+        Returns:
+            out (Dict[str, List]): a dictionary containing the best hypothesis for each image
+                ("text") and its confidence ("confidence"). There is no character-based probability.
+        """
+        # Reshape from (batch_size, n_tokens, n_frames) to (batch_size, n_frames, n_tokens)
+        batch_features = batch_features.permute((0, 2, 1))
+
+        # Apply temperature scaling
+        batch_features = batch_features / self.temperature
+
+        # Apply log softmax
+        batch_features = torch.nn.functional.log_softmax(batch_features, dim=-1)
+        # batch_features = self.add_ctc_frames(batch_features)
+        # batch_sizes = batch_features.shape[0]
+
+        # No GPU support for torchaudio's ctc_decoder
+        device = torch.device("cpu")
+        batch_features = batch_features.to(device)
+        if isinstance(batch_sizes, list):
+            batch_sizes = torch.tensor(batch_sizes)
+        batch_sizes = batch_sizes.to(device)
+
+        # Decode
+        hypotheses = self.decoder(batch_features, batch_sizes)
+        return self.post_process(hypotheses)
diff --git a/dan/ocr/predict/__init__.py b/dan/ocr/predict/__init__.py
index 952bcd6a..518b4eb0 100644
--- a/dan/ocr/predict/__init__.py
+++ b/dan/ocr/predict/__init__.py
@@ -167,4 +167,10 @@ def add_predict_parser(subcommands) -> None:
         type=str,
         required=False,
     )
+    parser.add_argument(
+        "--use-language-model",
+        help="Whether to use an explicit language model to rescore text hypotheses.",
+        action="store_true",
+        required=False,
+    )
     parser.set_defaults(func=run)
diff --git a/dan/ocr/predict/prediction.py b/dan/ocr/predict/prediction.py
index 130692f1..e7da4c88 100644
--- a/dan/ocr/predict/prediction.py
+++ b/dan/ocr/predict/prediction.py
@@ -12,7 +12,7 @@ import numpy as np
 import torch
 import yaml
 
-from dan.ocr.decoder import GlobalHTADecoder
+from dan.ocr.decoder import CTCLanguageDecoder, GlobalHTADecoder
 from dan.ocr.encoder import FCN_Encoder
 from dan.ocr.predict.attention import (
     Level,
@@ -75,6 +75,17 @@ class DAN:
 
         decoder = GlobalHTADecoder(parameters["decoder"]).to(self.device)
         decoder.load_state_dict(checkpoint["decoder_state_dict"], strict=True)
+
+        self.lm_decoder = CTCLanguageDecoder(
+            language_model_path=parameters["lm_decoder"]["language_model_path"],
+            lexicon_path=parameters["lm_decoder"]["lexicon_path"],
+            tokens_path=parameters["lm_decoder"]["tokens_path"],
+            language_model_weight=parameters["lm_decoder"]["language_model_weight"],
+            blank_token=parameters["lm_decoder"]["blank_token"],
+            unk_token=parameters["lm_decoder"]["unk_token"],
+            sil_token=parameters["lm_decoder"]["sil_token"],
+        )
+
         logger.debug(f"Loaded model {model_path}")
 
         if mode == "train":
@@ -125,6 +136,7 @@ class DAN:
         threshold_method: str = "otsu",
         threshold_value: int = 0,
         max_object_height: int = 50,
+        use_language_model: bool = False,
     ) -> dict:
         """
         Run prediction on an input image.
@@ -162,6 +174,12 @@
             (batch_size,), dtype=torch.int, device=self.device
         )
 
+        tot_pred = torch.zeros(
+            (batch_size, len(self.charset) + 1, self.max_chars),
+            dtype=torch.float,
+            device=self.device,
+        )
+
         whole_output = list()
         confidence_scores = list()
         attention_maps = list()
@@ -192,6 +210,9 @@
                 num_pred=1,
             )
 
+            # Store the raw logits for the language model decoder
+            tot_pred[:, :, i : i + 1] = pred
+
             pred = pred / self.temperature
             whole_output.append(output)
             attention_maps.append(weights)
@@ -242,6 +263,8 @@
 
         out = {}
         out["text"] = predicted_text
+        if use_language_model:
+            out["language_model"] = self.lm_decoder(tot_pred, predicted_tokens_len)
         if confidences:
             out["confidences"] = confidence_scores
         if attentions:
@@ -296,6 +319,7 @@ def process_batch(
     max_object_height: int,
     tokens: Dict[str, EntityType],
     start_token: str,
+    use_language_model: bool,
 ) -> None:
     input_images, visu_images, input_sizes = [], [], []
     logger.info("Loading images...")
@@ -330,6 +354,7 @@
         threshold_value=threshold_value,
         max_object_height=max_object_height,
         start_token=start_token,
+        use_language_model=use_language_model,
     )
 
     logger.info("Prediction parsing...")
@@ -337,6 +362,14 @@
         predicted_text = prediction["text"][idx]
         result = {"text": predicted_text}
 
+        # Return LM results
+        if use_language_model:
+            result["language_model"] = {}
+            result["language_model"]["text"] = prediction["language_model"]["text"][idx]
+            result["language_model"]["confidence"] = prediction["language_model"][
+                "confidence"
+            ][idx]
+
         # Return extracted objects (coordinates, text, confidence)
         if predict_objects:
             result["objects"] = prediction["objects"][idx]
@@ -435,6 +468,7 @@ def run(
     batch_size: int,
     tokens: Dict[str, EntityType],
     start_token: str,
+    use_language_model: bool,
 ) -> None:
     """
     Predict a single image save the output
@@ -458,6 +492,7 @@ def run(
     :param batch_size: Size of the batches for prediction.
     :param tokens: NER tokens used.
     :param start_token: Use a specific starting token at the beginning of the prediction. Useful when making predictions on different single pages.
+    :param use_language_model: Whether to use an explicit language model to rescore text hypotheses.
     """
     # Create output directory if necessary
     if not output.exists():
@@ -489,4 +524,5 @@ def run(
         max_object_height,
         tokens,
         start_token,
+        use_language_model,
     )
diff --git a/dan/utils.py b/dan/utils.py
index d825adf7..f813723e 100644
--- a/dan/utils.py
+++ b/dan/utils.py
@@ -137,3 +137,12 @@ def read_json(json_path: str) -> Dict:
         return json.loads(filename.read_text())
     except json.JSONDecodeError as e:
         raise ArgumentTypeError(e)
+
+
+def read_txt(txt_path: str) -> str:
+    """
+    Read TXT file
+    """
+    filename = Path(txt_path)
+    assert filename.exists(), f"{txt_path} does not resolve."
+    return filename.read_text()
diff --git a/requirements.txt b/requirements.txt
index 99e94e92..d3715a7e 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -13,5 +13,6 @@ teklia-line-image-extractor==0.2.8rc4
 tenacity==8.2.3
 tensorboard==2.12.2
 torch==2.0.0
+torchaudio==2.0.1
 torchvision==0.15.1
 tqdm==4.65.0
-- 
GitLab
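
For reviewers who want to exercise the new decoder in isolation, here is a minimal sketch (not part of the patch). The file names (lm.arpa, lexicon.txt, tokens.txt), the token inventory, and the tensor sizes below are illustrative assumptions; only CTCLanguageDecoder, its constructor arguments, and its call signature come from the diff above.

    import torch

    from dan.ocr.decoder import CTCLanguageDecoder

    # tokens.txt is assumed to hold one token per line, in the same order as the
    # model charset and including the blank symbol, e.g. <ctc>, <space>, a, b, ...
    decoder = CTCLanguageDecoder(
        language_model_path="lm.arpa",  # assumed KenLM/ARPA n-gram model
        lexicon_path="lexicon.txt",     # assumed lines of the form "word w o r d"
        tokens_path="tokens.txt",
        language_model_weight=1.0,
        temperature=1.0,
    )

    # Raw (unscaled) logits, shaped like `tot_pred` in the prediction loop above:
    # (batch_size, n_tokens, n_frames), where n_tokens matches the token file.
    batch_size, n_tokens, n_frames = 2, 4, 10
    logits = torch.randn(batch_size, n_tokens, n_frames)
    lengths = [10, 7]  # number of decoding steps produced for each image

    out = decoder(logits, lengths)
    print(out["text"])        # best hypothesis per image, spaces restored
    print(out["confidence"])  # exp(score / last timestep), a pseudo-probability

End to end, the same path is enabled by the new --use-language-model flag (e.g. teklia-dan predict ... --use-language-model, assuming the usual entry point), provided the model's parameter file exposes the lm_decoder keys read at load time; the rescored transcription then appears under language_model.text and language_model.confidence in the output JSON. Note that the confidence is the best hypothesis' log-score normalised by its final timestep and exponentiated, i.e. a per-frame pseudo-probability rather than a true posterior.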