Skip to content
Snippets Groups Projects

Add Language Model Decoder

Merged Solene Tarride requested to merge lm-decoder into main
All threads resolved!
1 file
+15 -13
Compare changes
  • Side-by-side
  • Inline
+15 -13
@@ -516,17 +516,19 @@ class CTCLanguageDecoder:
# column with 1 probability on CTC token
ctc_probs = (
- torch.ones((batch_size, 1, n_tokens), dtype=torch.float32) * 0.1 / n_tokens
+ torch.ones((batch_size, 1, n_tokens), dtype=torch.float32)
+ * 0.01
+ / (n_tokens - 1)
)
- ctc_probs[:, :, self.ctc_id] = 0.9
+ ctc_probs[:, :, self.ctc_id] = 0.99
ctc_probs = ctc_probs.log()
for i in range(n_frames - 1):
batch_features = torch.cat(
[
- batch_features[:, 2 * i + 1 :, :],
- ctc_probs,
batch_features[:, : 2 * i + 1, :],
+ ctc_probs,
+ batch_features[:, 2 * i + 1 :, :],
],
dim=1,
)
@@ -534,14 +536,15 @@ class CTCLanguageDecoder:
def post_process(self, hypotheses):
"""
- Post-process hypotheses to output JSON
+ Post-process hypotheses to output JSON. Exports only the best hypothesis for each image.
"""
out = {}
- # Export only the best hypothesis
+ # Replace <space> by an actual space and format string
out["text"] = [
"".join(hypothesis[0].words).replace(self.space_token, " ")
for hypothesis in hypotheses
]
+ # Normalize confidence score
out["confidence"] = [
np.exp(hypothesis[0].score / hypothesis[0].timesteps[-1].item())
for hypothesis in hypotheses
@@ -552,13 +555,12 @@ class CTCLanguageDecoder:
"""
Decode a feature vector using n-gram language modelling.
Args:
- features (Any): feature vector of size (n_frame, batch_size, n_tokens).
- Can be either a torch.tensor or a torch.nn.utils.rnn.PackedSequence
+ features (torch.tensor): feature vector of size (batch_size, n_tokens, n_frame).
+ batch_sizes (Union[List, torch.tensor]): actual length of predictions
Returns:
- out (Dict[str, List]): a dictionary containing the hypothesis (the list of decoded tokens).
- There is no character-based probability.
+ out (Dict[List]): a dictionary containing the hypotheses.
"""
- # Reshape from (n_frame, batch_size, n_tokens) to (batch_size, n_frame, n_tokens)
+ # Reshape from (batch_size, n_tokens, n_frames) to (batch_size, n_frames, n_tokens)
batch_features = batch_features.permute((0, 2, 1))
# Apply temperature scaling
@@ -566,8 +568,8 @@ class CTCLanguageDecoder:
# Apply log softmax
batch_features = torch.nn.functional.log_softmax(batch_features, dim=-1)
- # batch_features = self.add_ctc_frames(batch_features)
- # batch_sizes = batch_features.shape[0]
+ batch_features = self.add_ctc_frames(batch_features)
+ batch_sizes *= 2
# No GPU support for torchaudio's ctc_decoder
device = torch.device("cpu")
Loading