Skip to content
Snippets Groups Projects
Commit 3af7de77 authored by Solene Tarride's avatar Solene Tarride
Browse files

Move tensor to correct device and trim prediction

parent dd279dab
No related branches found
No related tags found
No related merge requests found
......@@ -519,12 +519,22 @@ class CTCLanguageDecoder:
batch_size, n_frames, n_tokens = batch_features.shape
# Reset probabilities for the CTC token
batch_features[:, :, -1] = (
torch.ones((batch_size, n_frames), dtype=torch.float32) * low_prob
torch.ones(
(batch_size, n_frames),
dtype=torch.float32,
device=batch_features.device,
)
* low_prob
)
# Create a frame with high probability CTC token
ctc_probs = (
torch.ones((batch_size, 1, n_tokens), dtype=torch.float32) * low_prob
torch.ones(
(batch_size, 1, n_tokens),
dtype=torch.float32,
device=batch_features.device,
)
* low_prob
)
ctc_probs[:, :, self.blank_token_id] = high_prob
ctc_probs = ctc_probs
......@@ -558,7 +568,7 @@ class CTCLanguageDecoder:
else self.index_to_token[token]
for token in hypothesis[0].tokens.tolist()
]
)
).strip()
for hypothesis in hypotheses
]
# Normalize confidence score
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment