Skip to content
Snippets Groups Projects
Commit b26de6ed authored by Solene Tarride's avatar Solene Tarride
Browse files

Move tensor to correct device and trim prediction

parent dcef6444
No related branches found
No related tags found
No related merge requests found
......@@ -519,12 +519,22 @@ class CTCLanguageDecoder:
batch_size, n_frames, n_tokens = batch_features.shape
# Reset probabilities for the CTC token
batch_features[:, :, -1] = (
torch.ones((batch_size, n_frames), dtype=torch.float32) * low_prob
torch.ones(
(batch_size, n_frames),
dtype=torch.float32,
device=batch_features.device,
)
* low_prob
)
# Create a frame with high probability CTC token
ctc_probs = (
torch.ones((batch_size, 1, n_tokens), dtype=torch.float32) * low_prob
torch.ones(
(batch_size, 1, n_tokens),
dtype=torch.float32,
device=batch_features.device,
)
* low_prob
)
ctc_probs[:, :, self.blank_token_id] = high_prob
ctc_probs = ctc_probs
......@@ -558,7 +568,7 @@ class CTCLanguageDecoder:
else self.index_to_token[token]
for token in hypothesis[0].tokens.tolist()
]
)
).strip()
for hypothesis in hypotheses
]
# Normalize confidence score
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment