From 2538568cbeb9e0ae692a81903b50d6d25407aa09 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sol=C3=A8ne=20Tarride?= <starride@teklia.com>
Date: Tue, 12 Sep 2023 09:51:57 +0200
Subject: [PATCH] Implement CTCLanguageDecoder

---
 dan/ocr/decoder.py           | 164 +++++++++++++----------------
 dan/ocr/predict/__init__.py  |   2 +-
 dan/ocr/predict/inference.py |  15 +++-
 dan/utils.py                 |   4 +-
 4 files changed, 76 insertions(+), 109 deletions(-)

diff --git a/dan/ocr/decoder.py b/dan/ocr/decoder.py
index b62c0f6c..8f80341a 100644
--- a/dan/ocr/decoder.py
+++ b/dan/ocr/decoder.py
@@ -1,15 +1,13 @@
 # -*- coding: utf-8 -*-
-from typing import Dict, List, Union
-
 import numpy as np
 import torch
 from torch import relu, softmax
 from torch.nn import Conv1d, Dropout, Embedding, LayerNorm, Linear, Module, ModuleList
 from torch.nn.init import xavier_uniform_
-from torchaudio.models.decoder import CTCHypothesis, ctc_decoder
+from torchaudio.models.decoder import ctc_decoder
 
-from dan.utils import LMTokenMapping, read_txt
+from dan.utils import read_txt
 
 
 class PositionalEncoding1D(Module):
@@ -470,13 +468,18 @@ class GlobalHTADecoder(Module):
 class CTCLanguageDecoder:
     """
     Initialize a CTC decoder with n-gram language modeling.
-    :param language_model_path: Path to a KenLM or ARPA language model.
-    :param lexicon_path: Path to a lexicon file containing the possible words and corresponding spellings.
-        Each line consists of a word and its space separated spelling. If `None`, uses lexicon-free decoding.
-    :param tokens_path: Path to a file containing valid tokens. If using a file, the expected
-        format is for tokens mapping to the same index to be on the same line.
-    :param language_model_weight: Weight of the language model.
-    :param temperature: Temperature for model calibreation.
+    Args:
+        language_model_path (str): path to a KenLM or ARPA language model
+        lexicon_path (str): path to a lexicon file containing the possible words and corresponding spellings.
+            Each line consists of a word and its space-separated spelling. If `None`, uses lexicon-free
+            decoding.
+        tokens_path (str): path to a file containing valid tokens. If using a file, the expected
+            format is for tokens mapping to the same index to be on the same line.
+        language_model_weight (float): weight of the language model.
+        blank_token (str): token representing the blank/CTC symbol
+        unk_token (str): token representing unknown characters
+        sil_token (str): token representing the space character
+        temperature (float): temperature for scaling the logits before decoding
     """

     def __init__(
@@ -485,138 +487,95 @@ class CTCLanguageDecoder:
         lexicon_path: str,
         tokens_path: str,
         language_model_weight: float = 1.0,
+        blank_token: str = "<ctc>",
+        unk_token: str = "<unk>",
+        sil_token: str = "<space>",
         temperature: float = 1.0,
     ):
-        self.mapping = LMTokenMapping()
-        self.language_model_weight = language_model_weight
-        self.temperature = temperature
-        self.tokens_to_index = {
-            token: i for i, token in enumerate(read_txt(tokens_path).split("\n"))
-        }
-        self.index_to_token = {i: token for token, i in self.tokens_to_index.items()}
-        self.blank_token_id = self.tokens_to_index[self.mapping.ctc.encoded]
-
-        # Torchaudio's decoder
-        # https://pytorch.org/audio/master/generated/torchaudio.models.decoder.ctc_decoder.html
         self.decoder = ctc_decoder(
             lm=language_model_path,
             lexicon=lexicon_path,
             tokens=tokens_path,
-            lm_weight=self.language_model_weight,
-            blank_token=self.mapping.ctc.encoded,
-            sil_token=self.mapping.space.encoded,
-            unk_word="⁇",
+            lm_weight=language_model_weight,
+            blank_token=blank_token,
+            unk_word=unk_token,
+            sil_token=sil_token,
             nbest=1,
         )
-        # No GPU support
-        self.device = torch.device("cpu")
+        self.temperature = temperature

-    def add_ctc_frames(
-        self, batch_features: torch.FloatTensor, batch_frames: torch.LongTensor
-    ) -> tuple[torch.FloatTensor, torch.LongTensor]:
+        self.tokens_to_idx = read_txt(tokens_path).split("\n")
+        self.ctc_id = self.tokens_to_idx.index(blank_token)
+        self.space_token = sil_token
+
+    def add_ctc_frames(self, batch_features):
         """
-        Add CTC frames between each characters to avoid duplicate removal.
+        Add CTC frames between characters to avoid duplicate removal
         """
-        high_prob = batch_features.max()
-        low_prob = batch_features.min()
         batch_size, n_frames, n_tokens = batch_features.shape
-        # Reset probabilities for the CTC token
-        batch_features[:, :, -1] = (
-            torch.ones(
-                (batch_size, n_frames),
-                dtype=torch.float32,
-                device=batch_features.device,
-            )
-            * low_prob
-        )
-        # Create a frame with high probability CTC token
+        # Frame with high probability (0.9) on the CTC token
         ctc_probs = (
-            torch.ones(
-                (batch_size, 1, n_tokens),
-                dtype=torch.float32,
-                device=batch_features.device,
-            )
-            * low_prob
+            torch.ones((batch_size, 1, n_tokens), dtype=torch.float32) * 0.1 / n_tokens
        )
-        ctc_probs[:, :, self.blank_token_id] = high_prob
-        ctc_probs = ctc_probs
+        ctc_probs[:, :, self.ctc_id] = 0.9
+        ctc_probs = ctc_probs.log()

-        # Insert the CTC frame between regular frames
-        for fn in range(batch_frames.max() - 1):
+        for i in range(n_frames - 1):
             batch_features = torch.cat(
                 [
-                    batch_features[:, : 2 * fn + 1, :],
+                    batch_features[:, : 2 * i + 1, :],
                     ctc_probs,
-                    batch_features[:, 2 * fn + 1 :, :],
+                    batch_features[:, 2 * i + 1 :, :],
                 ],
                 dim=1,
             )
+        return batch_features

-        # Update the number of frames
-        batch_frames = 2 * batch_frames - 1
-        return batch_features, batch_frames
-
-    def post_process(
-        self, hypotheses: List[CTCHypothesis], batch_sizes: torch.LongTensor
-    ) -> Dict[str, List[Union[str, float]]]:
+    def post_process(self, hypotheses):
         """
-        Post-process hypotheses to output JSON. Exports only the best hypothesis for each image.
-        :param hypotheses: List of hypotheses returned by the decoder.
-        :param batch_sizes: Prediction length of size batch_size.
-        :return: A dictionary containing the hypotheses and their confidences.
+        Post-process hypotheses to output JSON
         """
         out = {}
-        # Replace <space> by an actual space and format string
+        # Export only the best hypothesis
         out["text"] = [
-            "".join(
-                [
-                    self.mapping.display[self.index_to_token[token]]
-                    if self.index_to_token[token] in self.mapping.display
-                    else self.index_to_token[token]
-                    for token in hypothesis[0].tokens.tolist()
-                ]
-            ).strip()
+            "".join(hypothesis[0].words).replace(self.space_token, " ")
             for hypothesis in hypotheses
         ]
-        # Normalize confidence score
         out["confidence"] = [
-            np.around(
-                np.exp(
-                    hypothesis[0].score
-                    / ((self.language_model_weight + 1) * length.item())
-                ),
-                2,
-            )
-            for hypothesis, length in zip(hypotheses, batch_sizes)
+            np.exp(hypothesis[0].score / hypothesis[0].timesteps[-1].item())
+            for hypothesis in hypotheses
         ]
         return out

-    def __call__(
-        self, batch_features: torch.FloatTensor, batch_frames: torch.LongTensor
-    ) -> Dict[str, List[Union[str, float]]]:
+    def __call__(self, batch_features, batch_sizes):
         """
         Decode a feature vector using n-gram language modelling.
-        :param batch_features: Feature vector of size (batch_size, n_tokens, n_frames).
-        :param batch_frames: Prediction length of size batch_size.
-        :return: A dictionary containing the hypotheses and their confidences.
+        Args:
+            batch_features (torch.Tensor): feature vector of size (batch_size, n_tokens, n_frames)
+            batch_sizes (List[int] | torch.Tensor): length of the prediction for each batch element
+        Returns:
+            out (Dict[str, List]): a dictionary containing the decoded text and one sentence-level
+                confidence score per image. There are no character-based probabilities.
         """
         # Reshape from (batch_size, n_tokens, n_frames) to (batch_size, n_frames, n_tokens)
         batch_features = batch_features.permute((0, 2, 1))

-        # Insert CTC frames to avoid getting rid of duplicates
-        # Make sure that the CTC token has low probs for other frames
-        batch_features, batch_frames = self.add_ctc_frames(batch_features, batch_frames)
+        # Apply temperature scaling
+        batch_features = batch_features / self.temperature

         # Apply log softmax
-        batch_features = torch.nn.functional.log_softmax(
-            batch_features / self.temperature, dim=-1
-        )
+        batch_features = torch.nn.functional.log_softmax(batch_features, dim=-1)
+        # batch_features = self.add_ctc_frames(batch_features)
+        # batch_sizes = batch_features.shape[0]

         # No GPU support for torchaudio's ctc_decoder
-        batch_features = batch_features.to(self.device)
-        batch_frames = batch_frames.to(self.device)
+        device = torch.device("cpu")
+        batch_features = batch_features.to(device)
+        if isinstance(batch_sizes, list):
+            batch_sizes = torch.tensor(batch_sizes)
+        batch_sizes = batch_sizes.to(device)

         # Decode
-        hypotheses = self.decoder(batch_features, batch_frames)
-        return self.post_process(hypotheses, batch_frames)
+        hypotheses = self.decoder(batch_features, batch_sizes)
+        return self.post_process(hypotheses)
diff --git a/dan/ocr/predict/__init__.py b/dan/ocr/predict/__init__.py
index e5f7b2bc..da477c1f 100644
--- a/dan/ocr/predict/__init__.py
+++ b/dan/ocr/predict/__init__.py
@@ -169,7 +169,7 @@ def add_predict_parser(subcommands) -> None:
     )
     parser.add_argument(
         "--use-language-model",
-        help="Whether to use an explicit language model to rescore text hypotheses.",
+        help="Whether to rescore text hypotheses with an explicit language model.",
         action="store_true",
         required=False,
     )
diff --git a/dan/ocr/predict/inference.py b/dan/ocr/predict/inference.py
index e33d0ba2..a2096015 100644
--- a/dan/ocr/predict/inference.py
+++ b/dan/ocr/predict/inference.py
@@ -77,6 +77,16 @@ class DAN:
         decoder = GlobalHTADecoder(parameters["decoder"]).to(self.device)
         decoder.load_state_dict(checkpoint["decoder_state_dict"], strict=True)

+        self.lm_decoder = CTCLanguageDecoder(
+            language_model_path=parameters["lm_decoder"]["language_model_path"],
+            lexicon_path=parameters["lm_decoder"]["lexicon_path"],
+            tokens_path=parameters["lm_decoder"]["tokens_path"],
+            language_model_weight=parameters["lm_decoder"]["language_model_weight"],
+            blank_token=parameters["lm_decoder"]["blank_token"],
+            unk_token=parameters["lm_decoder"]["unk_token"],
+            sil_token=parameters["lm_decoder"]["sil_token"],
+        )
+
         logger.debug(f"Loaded model {model_path}")

         if mode == "train":
@@ -179,7 +189,6 @@ class DAN:
             (batch_size,), dtype=torch.int, device=self.device
         )

-        # end token index will be used for ctc
         tot_pred = torch.zeros(
             (batch_size, len(self.charset) + 1, self.max_chars),
             dtype=torch.float,
@@ -270,7 +279,7 @@ class DAN:

         out["text"] = predicted_text
         if use_language_model:
-            out["language_model"] = self.lm_decoder(tot_pred, prediction_len)
+            out["language_model"] = self.lm_decoder(tot_pred, predicted_tokens_len)
         if confidences:
             out["confidences"] = confidence_scores
         if attentions:
@@ -466,7 +475,7 @@ def run(
     :param batch_size: Size of the batches for prediction.
     :param tokens: NER tokens used.
     :param start_token: Use a specific starting token at the beginning of the prediction. Useful when making predictions on different single pages.
-    :param use_language_model: Whether to use an explicit language model to rescore text hypotheses.
+    :param use_language_model: Whether to rescore text hypotheses with an explicit language model.
     """
     # Create output directory if necessary
     if not output.exists():
diff --git a/dan/utils.py b/dan/utils.py
index 176c56c9..69e7d82a 100644
--- a/dan/utils.py
+++ b/dan/utils.py
@@ -163,9 +163,7 @@ def read_json(json_path: str) -> Dict:

 def read_txt(txt_path: str) -> str:
     """
-    Read TXT file.
-    :param txt_path: Path of the text file to read.
-    :return: The content of the read file.
+    Read a text file and return its content as a string.
     """
     filename = Path(txt_path)
     assert filename.exists(), f"{txt_path} does not resolve."
--
GitLab
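
A minimal usage sketch of the new CTCLanguageDecoder, for anyone who wants to exercise the class outside the DAN prediction pipeline. The file paths, vocabulary size, and batch shapes below are assumptions for illustration (and the KenLM/ARPA model must already exist on disk); only the constructor arguments and the call convention come from the patch itself.

# Hypothetical end-to-end usage of CTCLanguageDecoder; all paths are placeholders.
import torch

from dan.ocr.decoder import CTCLanguageDecoder

lm_decoder = CTCLanguageDecoder(
    language_model_path="lm/model.arpa",  # assumed pre-trained KenLM/ARPA model
    lexicon_path="lm/lexicon.txt",        # assumed lexicon file (see sketch below)
    tokens_path="lm/tokens.txt",          # assumed tokens file, one token per line
    language_model_weight=1.0,
    blank_token="<ctc>",
    unk_token="<unk>",
    sil_token="<space>",
    temperature=1.0,
)

# Stand-in for DAN's output: unnormalized scores of shape (batch_size, n_tokens, n_frames).
# n_tokens must equal the number of lines in lm/tokens.txt (6 in the toy files below).
batch_size, n_tokens, n_frames = 2, 6, 200
features = torch.randn(batch_size, n_tokens, n_frames)
batch_sizes = [200, 160]  # number of valid prediction frames per image

out = lm_decoder(features, batch_sizes)
print(out["text"])        # one rescored transcription per image
print(out["confidence"])  # one sentence-level confidence score per image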
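
The docstring describes the tokens and lexicon files only briefly, so the sketch below writes both for a toy three-character alphabet to make their shape concrete. The vocabulary is invented, and whether each spelling needs a trailing silence token depends on how the language model and lexicon were built; treat this as illustrative rather than canonical.

# Toy tokens/lexicon files in the shape expected by torchaudio's ctc_decoder.
from pathlib import Path

Path("lm").mkdir(exist_ok=True)

# tokens.txt: one token per line; the special tokens must match the
# blank_token/unk_token/sil_token arguments passed to CTCLanguageDecoder.
Path("lm/tokens.txt").write_text("\n".join(["<ctc>", "<space>", "<unk>", "a", "b", "c"]))

# lexicon.txt: one word per line, followed by its space-separated spelling.
Path("lm/lexicon.txt").write_text("\n".join(["abc a b c", "cab c a b"]))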
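
Although add_ctc_frames is commented out in __call__ for now, it addresses a real property of CTC decoding: consecutive identical tokens collapse into one unless a blank separates them. The helper below is not part of the patch; it is a standalone sketch of the standard collapse rule, showing why interleaving high-probability blank frames preserves doubled characters such as the "ll" in "hello".

# Illustration of CTC collapsing: repeats merge unless a blank sits between them.

def ctc_collapse(ids, blank_id=0):
    """Standard CTC post-processing: merge consecutive repeats, then drop blanks."""
    out, prev = [], None
    for i in ids:
        if i != prev and i != blank_id:
            out.append(i)
        prev = i
    return out

# Two consecutive frames predicting the same character merge into one...
assert ctc_collapse([3, 3]) == [3]
# ...but survive as a double letter when a blank frame is interleaved,
# which is exactly what add_ctc_frames simulates.
assert ctc_collapse([3, 0, 3]) == [3, 3]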