Skip to content
Snippets Groups Projects
Commit e9aa1084 authored by Solene Tarride's avatar Solene Tarride
Browse files

Add a CTC frame between each pair of consecutive frames

parent 16d83adb
No related branches found
No related tags found
No related merge requests found
......@@ -516,17 +516,19 @@ class CTCLanguageDecoder:
# column with almost all of the probability mass on the CTC token
ctc_probs = (
torch.ones((batch_size, 1, n_tokens), dtype=torch.float32) * 0.1 / n_tokens
torch.ones((batch_size, 1, n_tokens), dtype=torch.float32)
* 0.01
/ (n_tokens - 1)
)
ctc_probs[:, :, self.ctc_id] = 0.9
ctc_probs[:, :, self.ctc_id] = 0.99
ctc_probs = ctc_probs.log()
for i in range(n_frames - 1):
batch_features = torch.cat(
[
batch_features[:, 2 * i + 1 :, :],
ctc_probs,
batch_features[:, : 2 * i + 1, :],
ctc_probs,
batch_features[:, 2 * i + 1 :, :],
],
dim=1,
)
......@@ -534,14 +536,15 @@ class CTCLanguageDecoder:
def post_process(self, hypotheses):
"""
Post-process hypotheses to output JSON
Post-process hypotheses to output JSON. Exports only the best hypothesis for each image.
"""
out = {}
# Export only the best hypothesis
# Replace <space> by an actual space and format string
out["text"] = [
"".join(hypothesis[0].words).replace(self.space_token, " ")
for hypothesis in hypotheses
]
# Normalize confidence score
out["confidence"] = [
np.exp(hypothesis[0].score / hypothesis[0].timesteps[-1].item())
for hypothesis in hypotheses
......@@ -552,13 +555,12 @@ class CTCLanguageDecoder:
"""
Decode a feature vector using n-gram language modelling.
Args:
features (Any): feature vector of size (n_frame, batch_size, n_tokens).
Can be either a torch.tensor or a torch.nn.utils.rnn.PackedSequence
features (torch.tensor): feature vector of size (batch_size, n_tokens, n_frame).
batch_sizes (Union[List, torch.tensor]): actual length of predictions
Returns:
out (Dict[str, List]): a dictionary containing the hypothesis (the list of decoded tokens).
There is no character-based probability.
out (Dict[List]): a dictionary containing the hypotheses.
"""
# Reshape from (n_frame, batch_size, n_tokens) to (batch_size, n_frame, n_tokens)
# Reshape from (batch_size, n_tokens, n_frames) to (batch_size, n_frames, n_tokens)
batch_features = batch_features.permute((0, 2, 1))
# Apply temperature scaling
......@@ -566,8 +568,8 @@ class CTCLanguageDecoder:
# Apply log softmax
batch_features = torch.nn.functional.log_softmax(batch_features, dim=-1)
# batch_features = self.add_ctc_frames(batch_features)
# batch_sizes = batch_features.shape[0]
batch_features = self.add_ctc_frames(batch_features)
batch_sizes *= 2
# No GPU support for torchaudio's ctc_decoder
device = torch.device("cpu")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment