Commit 074f0b39 authored by Solene Tarride

Add CTC frame between each frame

parent c5f5a535
1 merge request: !287 Support subword and word language models
@@ -516,17 +516,19 @@ class CTCLanguageDecoder:
         # column with 1 probability on CTC token
         ctc_probs = (
-            torch.ones((batch_size, 1, n_tokens), dtype=torch.float32) * 0.1 / n_tokens
+            torch.ones((batch_size, 1, n_tokens), dtype=torch.float32)
+            * 0.01
+            / (n_tokens - 1)
         )
-        ctc_probs[:, :, self.ctc_id] = 0.9
+        ctc_probs[:, :, self.ctc_id] = 0.99
         ctc_probs = ctc_probs.log()
         for i in range(n_frames - 1):
             batch_features = torch.cat(
                 [
-                    batch_features[:, 2 * i + 1 :, :],
-                    ctc_probs,
+                    batch_features[:, : 2 * i + 1, :],
+                    ctc_probs,
+                    batch_features[:, 2 * i + 1 :, :],
                 ],
                 dim=1,
             )
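
A minimal standalone sketch of what the updated loop does (toy tensor sizes and ctc_id, not the project's values): the new constants make the inserted blank frame a proper probability distribution (0.99 on the CTC token plus (n_tokens - 1) * 0.01 / (n_tokens - 1) on the remaining tokens sums to 1), and the loop places one such frame between every pair of consecutive frames, growing n_frames to 2 * n_frames - 1 so that genuine repeated characters are not collapsed during CTC decoding.

    import torch

    batch_size, n_frames, n_tokens, ctc_id = 1, 3, 5, 0  # toy values

    # fake log-probabilities for 3 frames
    batch_features = torch.log_softmax(torch.randn(batch_size, n_frames, n_tokens), dim=-1)

    # blank frame: 0.99 on the CTC token, 0.01 shared among the other tokens
    ctc_probs = torch.ones((batch_size, 1, n_tokens), dtype=torch.float32) * 0.01 / (n_tokens - 1)
    ctc_probs[:, :, ctc_id] = 0.99
    ctc_probs = ctc_probs.log()

    # insert a blank frame after every original frame except the last one
    for i in range(n_frames - 1):
        batch_features = torch.cat(
            [
                batch_features[:, : 2 * i + 1, :],
                ctc_probs,
                batch_features[:, 2 * i + 1 :, :],
            ],
            dim=1,
        )

    print(batch_features.shape)  # torch.Size([1, 5, 5]) -> 2 * n_frames - 1 frames
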
@@ -534,14 +536,15 @@ class CTCLanguageDecoder:
     def post_process(self, hypotheses):
         """
-        Post-process hypotheses to output JSON
+        Post-process hypotheses to output JSON. Exports only the best hypothesis for each image.
         """
         out = {}
-        # Export only the best hypothesis
+        # Replace <space> by an actual space and format string
         out["text"] = [
             "".join(hypothesis[0].words).replace(self.space_token, " ")
             for hypothesis in hypotheses
         ]
+        # Normalize confidence score
         out["confidence"] = [
             np.exp(hypothesis[0].score / hypothesis[0].timesteps[-1].item())
             for hypothesis in hypotheses
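
A quick numeric illustration of the confidence normalisation above (the numbers are invented): the hypothesis score returned by torchaudio's CTC decoder is a sum of log-probabilities over frames, so dividing it by the index of the last timestep and exponentiating gives roughly the geometric mean of the per-frame probabilities, a value between 0 and 1.

    import numpy as np

    score = -12.5        # cumulative log-probability of the best hypothesis (made up)
    last_timestep = 50   # hypothesis[0].timesteps[-1].item(): frame index of the last emission

    confidence = np.exp(score / last_timestep)
    print(round(confidence, 3))  # 0.779
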
@@ -552,13 +555,12 @@ class CTCLanguageDecoder:
         """
         Decode a feature vector using n-gram language modelling.
         Args:
-            features (Any): feature vector of size (n_frame, batch_size, n_tokens).
-                Can be either a torch.tensor or a torch.nn.utils.rnn.PackedSequence
+            features (torch.tensor): feature vector of size (batch_size, n_tokens, n_frame).
             batch_sizes (Union[List, torch.tensor]): actual length of predictions
         Returns:
-            out (Dict[str, List]): a dictionary containing the hypothesis (the list of decoded tokens).
-                There is no character-based probability.
+            out (Dict[List]): a dictionary containing the hypotheses.
         """
-        # Reshape from (n_frame, batch_size, n_tokens) to (batch_size, n_frame, n_tokens)
+        # Reshape from (batch_size, n_tokens, n_frames) to (batch_size, n_frames, n_tokens)
         batch_features = batch_features.permute((0, 2, 1))
         # Apply temperature scaling
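
A shape sanity check for the updated docstring and comment (toy sizes, and a standard logits / temperature division is assumed for the scaling step, which is not shown in this hunk): features arrive as (batch_size, n_tokens, n_frames) and are permuted to (batch_size, n_frames, n_tokens) before temperature scaling and the log softmax over the token dimension.

    import torch

    batch_size, n_tokens, n_frames = 2, 100, 128
    temperature = 1.0  # assumed scaling parameter

    batch_features = torch.randn(batch_size, n_tokens, n_frames)

    # (batch_size, n_tokens, n_frames) -> (batch_size, n_frames, n_tokens)
    batch_features = batch_features.permute((0, 2, 1))

    # temperature scaling followed by log softmax over the token dimension
    batch_features = torch.nn.functional.log_softmax(batch_features / temperature, dim=-1)
    print(batch_features.shape)  # torch.Size([2, 128, 100])
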
@@ -566,8 +568,8 @@ class CTCLanguageDecoder:
         # Apply log softmax
         batch_features = torch.nn.functional.log_softmax(batch_features, dim=-1)
-        # batch_features = self.add_ctc_frames(batch_features)
-        # batch_sizes = batch_features.shape[0]
+        batch_features = self.add_ctc_frames(batch_features)
+        batch_sizes *= 2
         # No GPU support for torchaudio's ctc_decoder
         device = torch.device("cpu")
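
For context, a hedged sketch of the CPU-only decoding step that follows: torchaudio's ctc_decoder has no GPU support, so both the padded emissions and the per-image lengths must live on the CPU. The toy token list, the lexicon-free / no-LM configuration and the tensor values below are assumptions for illustration only; the real class builds its decoder from the project's own tokens and n-gram language model.

    import torch
    from torchaudio.models.decoder import ctc_decoder

    # toy token set; in the real class these come from the model vocabulary
    tokens = ["<ctc>", "a", "b", "c", "<space>"]

    decoder = ctc_decoder(
        lexicon=None,        # lexicon-free decoding for this sketch
        tokens=tokens,
        lm=None,             # a real setup would pass an n-gram LM here
        blank_token="<ctc>",
        sil_token="<space>",
        nbest=1,
        beam_size=10,
    )

    # no GPU support for torchaudio's ctc_decoder: keep everything on the CPU
    device = torch.device("cpu")
    batch_features = torch.randn(1, 20, len(tokens)).log_softmax(dim=-1).to(device)
    batch_sizes = torch.tensor([20], dtype=torch.int32)  # actual number of frames per image

    hypotheses = decoder(batch_features, batch_sizes)  # List[List[CTCHypothesis]]
    best = hypotheses[0][0]
    print(best.tokens, best.score, best.timesteps)
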