diff --git a/dan/ocr/decoder.py b/dan/ocr/decoder.py index 6dfda6d9fb39f2885f1d42b1ea39661bcb6b8a3f..8279d205f59cab24b4c9248efbd9f9b5e07f2e64 100644 --- a/dan/ocr/decoder.py +++ b/dan/ocr/decoder.py @@ -519,12 +519,22 @@ class CTCLanguageDecoder: batch_size, n_frames, n_tokens = batch_features.shape # Reset probabilities for the CTC token batch_features[:, :, -1] = ( - torch.ones((batch_size, n_frames), dtype=torch.float32) * low_prob + torch.ones( + (batch_size, n_frames), + dtype=torch.float32, + device=batch_features.device, + ) + * low_prob ) # Create a frame with high probability CTC token ctc_probs = ( - torch.ones((batch_size, 1, n_tokens), dtype=torch.float32) * low_prob + torch.ones( + (batch_size, 1, n_tokens), + dtype=torch.float32, + device=batch_features.device, + ) + * low_prob ) ctc_probs[:, :, self.blank_token_id] = high_prob ctc_probs = ctc_probs @@ -558,7 +568,7 @@ class CTCLanguageDecoder: else self.index_to_token[token] for token in hypothesis[0].tokens.tolist() ] - ) + ).strip() for hypothesis in hypotheses ] # Normalize confidence score