From bffe2c05fbd7d85bf0db34f29aa4ee5c7bb55697 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sol=C3=A8ne=20Tarride?= <starride@teklia.com> Date: Sat, 23 Sep 2023 09:17:26 +0200 Subject: [PATCH] Fix CTC token probability --- dan/ocr/decoder.py | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/dan/ocr/decoder.py b/dan/ocr/decoder.py index 1b8b89d2..b9b256d1 100644 --- a/dan/ocr/decoder.py +++ b/dan/ocr/decoder.py @@ -513,20 +513,23 @@ class CTCLanguageDecoder: """ Add CTC frames between each characters to avoid duplicate removal """ - batch_size, _, n_tokens = batch_features.shape + high_prob = batch_features.max() + low_prob = batch_features.min() + batch_size, n_frames, n_tokens = batch_features.shape - # Create tensor with high probability CTC token - high_prob = 0.99 - low_prob = 1 - high_prob + # Reset probabilities for the CTC token + batch_features[:, :, -1] = ( + torch.ones((batch_size, n_frames), dtype=torch.float32) * low_prob + ) + + # Create a frame with high probability CTC token ctc_probs = ( - torch.ones((batch_size, 1, n_tokens), dtype=torch.float32) - * low_prob - / (n_tokens - 1) + torch.ones((batch_size, 1, n_tokens), dtype=torch.float32) * low_prob ) ctc_probs[:, :, self.blank_token_id] = high_prob - ctc_probs = ctc_probs.log() + ctc_probs = ctc_probs - # Insert CTC tensor between frames + # Insert the CTC frame between regular frames for fn in range(batch_frames.max() - 1): batch_features = torch.cat( [ @@ -579,11 +582,14 @@ class CTCLanguageDecoder: # Reshape from (batch_size, n_tokens, n_frames) to (batch_size, n_frames, n_tokens) batch_features = batch_features.permute((0, 2, 1)) + # Insert CTC frames to avoid getting rid of duplicates + # Make sure that the CTC token has low probs for other frames + batch_features, batch_frames = self.add_ctc_frames(batch_features, batch_frames) + # Apply log softmax batch_features = torch.nn.functional.log_softmax( batch_features / self.temperature, dim=-1 ) - batch_features, batch_frames = self.add_ctc_frames(batch_features, batch_frames) # No GPU support for torchaudio's ctc_decoder batch_features = batch_features.to(self.device) -- GitLab