From 1cccd99b715f2d954f68ea6fcb82b3740b50ed33 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sol=C3=A8ne=20Tarride?= <starride@teklia.com>
Date: Tue, 19 Sep 2023 12:41:40 +0200
Subject: [PATCH] Fix shape of tot_prob

---
 dan/ocr/decoder.py           | 78 ++++++++++++++++++++----------------
 dan/ocr/predict/inference.py | 12 ++++----
 2 files changed, 48 insertions(+), 42 deletions(-)

diff --git a/dan/ocr/decoder.py b/dan/ocr/decoder.py
index 59909d14..bd1044d7 100644
--- a/dan/ocr/decoder.py
+++ b/dan/ocr/decoder.py
@@ -7,7 +7,7 @@ from torch.nn import Conv1d, Dropout, Embedding, LayerNorm, Linear, Module, Modu
 from torch.nn.init import xavier_uniform_
 from torchaudio.models.decoder import ctc_decoder
 
-from dan.utils import read_txt
+from dan.utils import LM_MAPPING, read_txt
 
 
 class PositionalEncoding1D(Module):
@@ -487,56 +487,63 @@ class CTCLanguageDecoder:
         lexicon_path: str,
         tokens_path: str,
         language_model_weight: float = 1.0,
-        blank_token: str = "<ctc>",
-        unk_token: str = "<unk>",
-        sil_token: str = "<space>",
         temperature: float = 1.0,
     ):
+        self.space_token = LM_MAPPING[" "]
+        self.unknown_token = LM_MAPPING["<unk>"]
+        self.blank_token = LM_MAPPING["<ctc>"]
+        self.language_model_weight = language_model_weight
+        self.temperature = temperature
+        self.tokens_to_index = {
+            token: i for i, token in enumerate(read_txt(tokens_path).split("\n"))
+        }
+        self.blank_token_id = self.tokens_to_index[self.blank_token]
         self.decoder = ctc_decoder(
             lm=language_model_path,
             lexicon=lexicon_path,
             tokens=tokens_path,
-            lm_weight=language_model_weight,
-            blank_token=blank_token,
-            unk_word=unk_token,
-            sil_token=sil_token,
+            lm_weight=self.language_model_weight,
+            blank_token=self.blank_token,
+            unk_word=self.unknown_token,
+            sil_token=self.space_token,
             nbest=1,
         )
-        self.temperature = temperature
-        self.space_token = sil_token
-        self.tokens_to_idx = read_txt(tokens_path).split("\n")
-        self.blank_id = self.tokens_to_idx.index(blank_token)
+        # No GPU support
+        self.device = torch.device("cpu")
 
-    def add_ctc_frames(self, batch_features):
+    def add_ctc_frames(self, batch_features, batch_frames):
         """
         Add CTC frames between each characters to avoid duplicate removal
         """
-        batch_size, n_frames, n_tokens = batch_features.shape
+        batch_size, _, n_tokens = batch_features.shape
 
         # Create tensor with high probability CTC token
         high_prob = 0.99
-        low_prob = 0.01
+        low_prob = 1 - high_prob
         ctc_probs = (
             torch.ones((batch_size, 1, n_tokens), dtype=torch.float32)
             * low_prob
             / (n_tokens - 1)
         )
-        ctc_probs[:, :, self.blank_id] = high_prob
+        ctc_probs[:, :, self.blank_token_id] = high_prob
         ctc_probs = ctc_probs.log()
 
         # Insert CTC tensor between frames
-        for i in range(n_frames):
+        for fn in range(batch_frames[0] - 1):
             batch_features = torch.cat(
                 [
-                    batch_features[:, : 2 * i + 1, :],
+                    batch_features[:, : 2 * fn + 1, :],
                     ctc_probs,
-                    batch_features[:, 2 * i + 1 :, :],
+                    batch_features[:, 2 * fn + 1 :, :],
                 ],
                 dim=1,
             )
-        return batch_features
 
-    def post_process(self, hypotheses):
+        # Update the number of frames
+        batch_frames = 2 * batch_frames - 1
+        return batch_features, batch_frames
+
+    def post_process(self, hypotheses, batch_sizes):
         """
         Post-process hypotheses to output JSON. Exports only the best hypothesis for each image.
""" @@ -548,12 +557,14 @@ class CTCLanguageDecoder: ] # Normalize confidence score out["confidence"] = [ - np.exp(hypothesis[0].score / hypothesis[0].timesteps[-1].item()) - for hypothesis in hypotheses + np.exp( + hypothesis[0].score / ((self.language_model_weight + 1) * length.item()) + ) + for hypothesis, length in zip(hypotheses, batch_sizes) ] return out - def __call__(self, batch_features, batch_sizes): + def __call__(self, batch_features, batch_frames): """ Decode a feature vector using n-gram language modelling. Args: @@ -565,21 +576,16 @@ class CTCLanguageDecoder: # Reshape from (batch_size, n_tokens, n_frames) to (batch_size, n_frames, n_tokens) batch_features = batch_features.permute((0, 2, 1)) - # Apply temperature scaling - batch_features = batch_features / self.temperature - # Apply log softmax - batch_features = torch.nn.functional.log_softmax(batch_features, dim=-1) - batch_features = self.add_ctc_frames(batch_features) - batch_sizes *= 2 + batch_features = torch.nn.functional.log_softmax( + batch_features / self.temperature, dim=-1 + ) + batch_features, batch_frames = self.add_ctc_frames(batch_features, batch_frames) # No GPU support for torchaudio's ctc_decoder - device = torch.device("cpu") - batch_features = batch_features.to(device) - if isinstance(batch_sizes, list): - batch_sizes = torch.tensor(batch_sizes) - batch_sizes.to(device) + batch_features = batch_features.to(self.device) + batch_frames = batch_frames.to(self.device) # Decode - hypotheses = self.decoder(batch_features, batch_sizes) - return self.post_process(hypotheses) + hypotheses = self.decoder(batch_features, batch_frames) + return self.post_process(hypotheses, batch_frames) diff --git a/dan/ocr/predict/inference.py b/dan/ocr/predict/inference.py index 4153633b..b04995d0 100644 --- a/dan/ocr/predict/inference.py +++ b/dan/ocr/predict/inference.py @@ -90,6 +90,7 @@ class DAN: self.encoder = encoder self.decoder = decoder + self.lm_decoder = None if use_language_model: self.lm_decoder = CTCLanguageDecoder( @@ -97,9 +98,6 @@ class DAN: lexicon_path=parameters["lm_decoder"]["lexicon_path"], tokens_path=parameters["lm_decoder"]["tokens_path"], language_model_weight=parameters["lm_decoder"]["language_model_weight"], - blank_token=parameters["lm_decoder"]["blank_token"], - unk_token=parameters["lm_decoder"]["unk_token"], - sil_token=parameters["lm_decoder"]["sil_token"], ) self.mean, self.std = ( @@ -178,6 +176,7 @@ class DAN: (batch_size,), dtype=torch.int, device=self.device ) + # end token index will be used for ctc tot_pred = torch.zeros( (batch_size, len(self.charset) + 1, self.max_chars), dtype=torch.float, @@ -268,7 +267,7 @@ class DAN: out["text"] = predicted_text if use_language_model: - out["language_model"] = self.lm_decoder(tot_pred, predicted_tokens_len) + out["language_model"] = self.lm_decoder(tot_pred, prediction_len) if confidences: out["confidences"] = confidence_scores if attentions: @@ -474,7 +473,14 @@ def run( cuda_device = f":{gpu_device}" if gpu_device is not None else "" device = f"cuda{cuda_device}" if torch.cuda.is_available() else "cpu" dan_model = DAN(device, temperature) +<<<<<<< HEAD dan_model.load(model, parameters, charset, mode="eval") +======= + dan_model.load( + model, parameters, charset, mode="eval", use_language_model=use_language_model + ) + batch_size = 1 if use_language_model else batch_size +>>>>>>> e7c611f (Fix shape of tot_prob) images = image_dir.rglob(f"*{image_extension}") if not image else [image] for image_batch in list_to_batches(images, n=batch_size): -- GitLab