From e09a5ec431502e23c226197ac3f2abe12deadc88 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sol=C3=A8ne=20Tarride?= <starride@teklia.com> Date: Mon, 25 Sep 2023 16:40:29 +0200 Subject: [PATCH] Move tensor to correct device and trim prediction --- dan/ocr/decoder.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/dan/ocr/decoder.py b/dan/ocr/decoder.py index 6dfda6d9..8279d205 100644 --- a/dan/ocr/decoder.py +++ b/dan/ocr/decoder.py @@ -519,12 +519,22 @@ class CTCLanguageDecoder: batch_size, n_frames, n_tokens = batch_features.shape # Reset probabilities for the CTC token batch_features[:, :, -1] = ( - torch.ones((batch_size, n_frames), dtype=torch.float32) * low_prob + torch.ones( + (batch_size, n_frames), + dtype=torch.float32, + device=batch_features.device, + ) + * low_prob ) # Create a frame with high probability CTC token ctc_probs = ( - torch.ones((batch_size, 1, n_tokens), dtype=torch.float32) * low_prob + torch.ones( + (batch_size, 1, n_tokens), + dtype=torch.float32, + device=batch_features.device, + ) + * low_prob ) ctc_probs[:, :, self.blank_token_id] = high_prob ctc_probs = ctc_probs @@ -558,7 +568,7 @@ class CTCLanguageDecoder: else self.index_to_token[token] for token in hypothesis[0].tokens.tolist() ] - ) + ).strip() for hypothesis in hypotheses ] # Normalize confidence score -- GitLab