diff --git a/dan/ocr/decoder.py b/dan/ocr/decoder.py index 1b8b89d22c8bb2af66b67888c981778edf9a8e63..b9b256d1c28ca72c5d9c280f1feec88d2cf9bf26 100644 --- a/dan/ocr/decoder.py +++ b/dan/ocr/decoder.py @@ -513,20 +513,23 @@ class CTCLanguageDecoder: """ Add CTC frames between each characters to avoid duplicate removal """ - batch_size, _, n_tokens = batch_features.shape + high_prob = batch_features.max() + low_prob = batch_features.min() + batch_size, n_frames, n_tokens = batch_features.shape - # Create tensor with high probability CTC token - high_prob = 0.99 - low_prob = 1 - high_prob + # Reset probabilities for the CTC token + batch_features[:, :, -1] = ( + torch.ones((batch_size, n_frames), dtype=torch.float32) * low_prob + ) + + # Create a frame with high probability CTC token ctc_probs = ( - torch.ones((batch_size, 1, n_tokens), dtype=torch.float32) - * low_prob - / (n_tokens - 1) + torch.ones((batch_size, 1, n_tokens), dtype=torch.float32) * low_prob ) ctc_probs[:, :, self.blank_token_id] = high_prob - ctc_probs = ctc_probs.log() + ctc_probs = ctc_probs - # Insert CTC tensor between frames + # Insert the CTC frame between regular frames for fn in range(batch_frames.max() - 1): batch_features = torch.cat( [ @@ -579,11 +582,14 @@ class CTCLanguageDecoder: # Reshape from (batch_size, n_tokens, n_frames) to (batch_size, n_frames, n_tokens) batch_features = batch_features.permute((0, 2, 1)) + # Insert CTC frames to avoid getting rid of duplicates + # Make sure that the CTC token has low probs for other frames + batch_features, batch_frames = self.add_ctc_frames(batch_features, batch_frames) + # Apply log softmax batch_features = torch.nn.functional.log_softmax( batch_features / self.temperature, dim=-1 ) - batch_features, batch_frames = self.add_ctc_frames(batch_features, batch_frames) # No GPU support for torchaudio's ctc_decoder batch_features = batch_features.to(self.device)