Skip to content
Snippets Groups Projects

Support subword and word language models

Merged Solene Tarride requested to merge subword-and-word-lm into main
All threads resolved!
1 file
+ 16
10
Compare changes
  • Side-by-side
  • Inline
+ 16
10
@@ -513,20 +513,23 @@ class CTCLanguageDecoder:
"""
Add CTC frames between each characters to avoid duplicate removal
"""
batch_size, _, n_tokens = batch_features.shape
high_prob = batch_features.max()
low_prob = batch_features.min()
batch_size, n_frames, n_tokens = batch_features.shape
# Create tensor with high probability CTC token
high_prob = 0.99
low_prob = 1 - high_prob
# Reset probabilities for the CTC token
batch_features[:, :, -1] = (
torch.ones((batch_size, n_frames), dtype=torch.float32) * low_prob
)
# Create a frame with high probability CTC token
ctc_probs = (
torch.ones((batch_size, 1, n_tokens), dtype=torch.float32)
* low_prob
/ (n_tokens - 1)
torch.ones((batch_size, 1, n_tokens), dtype=torch.float32) * low_prob
)
ctc_probs[:, :, self.blank_token_id] = high_prob
ctc_probs = ctc_probs.log()
ctc_probs = ctc_probs
# Insert CTC tensor between frames
# Insert the CTC frame between regular frames
for fn in range(batch_frames.max() - 1):
batch_features = torch.cat(
[
@@ -579,11 +582,14 @@ class CTCLanguageDecoder:
# Reshape from (batch_size, n_tokens, n_frames) to (batch_size, n_frames, n_tokens)
batch_features = batch_features.permute((0, 2, 1))
# Insert CTC frames to avoid getting rid of duplicates
# Make sure that the CTC token has low probs for other frames
batch_features, batch_frames = self.add_ctc_frames(batch_features, batch_frames)
# Apply log softmax
batch_features = torch.nn.functional.log_softmax(
batch_features / self.temperature, dim=-1
)
batch_features, batch_frames = self.add_ctc_frames(batch_features, batch_frames)
# No GPU support for torchaudio's ctc_decoder
batch_features = batch_features.to(self.device)
Loading