Skip to content
Snippets Groups Projects
Commit e09a5ec4 authored by Solene Tarride's avatar Solene Tarride
Browse files

Move tensor to correct device and trim prediction

parent e935baaf
No related branches found
No related tags found
No related merge requests found
...@@ -519,12 +519,22 @@ class CTCLanguageDecoder: ...@@ -519,12 +519,22 @@ class CTCLanguageDecoder:
batch_size, n_frames, n_tokens = batch_features.shape batch_size, n_frames, n_tokens = batch_features.shape
# Reset probabilities for the CTC token # Reset probabilities for the CTC token
batch_features[:, :, -1] = ( batch_features[:, :, -1] = (
torch.ones((batch_size, n_frames), dtype=torch.float32) * low_prob torch.ones(
(batch_size, n_frames),
dtype=torch.float32,
device=batch_features.device,
)
* low_prob
) )
# Create a frame with high probability CTC token # Create a frame with high probability CTC token
ctc_probs = ( ctc_probs = (
torch.ones((batch_size, 1, n_tokens), dtype=torch.float32) * low_prob torch.ones(
(batch_size, 1, n_tokens),
dtype=torch.float32,
device=batch_features.device,
)
* low_prob
) )
ctc_probs[:, :, self.blank_token_id] = high_prob ctc_probs[:, :, self.blank_token_id] = high_prob
ctc_probs = ctc_probs ctc_probs = ctc_probs
...@@ -558,7 +568,7 @@ class CTCLanguageDecoder: ...@@ -558,7 +568,7 @@ class CTCLanguageDecoder:
else self.index_to_token[token] else self.index_to_token[token]
for token in hypothesis[0].tokens.tolist() for token in hypothesis[0].tokens.tolist()
] ]
) ).strip()
for hypothesis in hypotheses for hypothesis in hypotheses
] ]
# Normalize confidence score # Normalize confidence score
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment