Commit 1cccd99b authored by Solene Tarride

Fix shape of tot_prob

parent 1a96389b
1 merge request: !287 Support subword and word language models
@@ -7,7 +7,7 @@ from torch.nn import Conv1d, Dropout, Embedding, LayerNorm, Linear, Module, Modu
 from torch.nn.init import xavier_uniform_
 from torchaudio.models.decoder import ctc_decoder
 
-from dan.utils import read_txt
+from dan.utils import LM_MAPPING, read_txt
 
 
 class PositionalEncoding1D(Module):
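
Note: the special tokens now come from dan.utils.LM_MAPPING instead of constructor arguments. A sketch of the mapping's shape; only the three keys are confirmed by the diff below, and the values are placeholders mirroring the removed constructor defaults:

# Hypothetical shape of dan.utils.LM_MAPPING; the keys appear in the diff,
# the values are assumed from the removed defaults.
LM_MAPPING = {
    " ": "<space>",    # silence token (old sil_token default)
    "<unk>": "<unk>",  # unknown-word token (old unk_token default)
    "<ctc>": "<ctc>",  # CTC blank token (old blank_token default)
}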
@@ -487,56 +487,65 @@ class CTCLanguageDecoder:
         lexicon_path: str,
         tokens_path: str,
         language_model_weight: float = 1.0,
-        blank_token: str = "<ctc>",
-        unk_token: str = "<unk>",
-        sil_token: str = "<space>",
         temperature: float = 1.0,
     ):
+        self.space_token = LM_MAPPING[" "]
+        self.unknown_token = LM_MAPPING["<unk>"]
+        self.blank_token = LM_MAPPING["<ctc>"]
+        self.language_model_weight = language_model_weight
+        self.temperature = temperature
+        self.tokens_to_index = {
+            token: i for i, token in enumerate(read_txt(tokens_path).split("\n"))
+        }
+        self.blank_token_id = self.tokens_to_index[self.blank_token]
         self.decoder = ctc_decoder(
             lm=language_model_path,
             lexicon=lexicon_path,
             tokens=tokens_path,
-            lm_weight=language_model_weight,
-            blank_token=blank_token,
-            unk_word=unk_token,
-            sil_token=sil_token,
+            lm_weight=self.language_model_weight,
+            blank_token=self.blank_token,
+            unk_word=self.unknown_token,
+            sil_token=self.space_token,
             nbest=1,
         )
-        self.temperature = temperature
-        self.space_token = sil_token
-        self.tokens_to_idx = read_txt(tokens_path).split("\n")
-        self.blank_id = self.tokens_to_idx.index(blank_token)
+        # No GPU support
+        self.device = torch.device("cpu")
 
-    def add_ctc_frames(self, batch_features):
+    def add_ctc_frames(self, batch_features, batch_frames):
         """
         Add CTC frames between characters to avoid duplicate removal
         """
-        batch_size, n_frames, n_tokens = batch_features.shape
+        batch_size, _, n_tokens = batch_features.shape
         torch.clone(batch_features)
+        # visualize_debug(batch_features.exp()[0, :batch_frames[0], :].numpy(), "probs.jpg", False)
 
         # Create tensor with high probability CTC token
         high_prob = 0.99
-        low_prob = 0.01
+        low_prob = 1 - high_prob
         ctc_probs = (
             torch.ones((batch_size, 1, n_tokens), dtype=torch.float32)
             * low_prob
             / (n_tokens - 1)
         )
-        ctc_probs[:, :, self.blank_id] = high_prob
+        ctc_probs[:, :, self.blank_token_id] = high_prob
         ctc_probs = ctc_probs.log()
 
         # Insert CTC tensor between frames
-        for i in range(n_frames):
+        for fn in range(batch_frames[0] - 1):
             batch_features = torch.cat(
                 [
-                    batch_features[:, : 2 * i + 1, :],
+                    batch_features[:, : 2 * fn + 1, :],
                     ctc_probs,
-                    batch_features[:, 2 * i + 1 :, :],
+                    batch_features[:, 2 * fn + 1 :, :],
                 ],
                 dim=1,
             )
-        return batch_features
+        # Update the number of frames
+        batch_frames = 2 * batch_frames - 1
+        return batch_features, batch_frames
 
-    def post_process(self, hypotheses):
+    def post_process(self, hypotheses, batch_sizes):
         """
         Post-process hypotheses to output JSON. Exports only the best hypothesis for each image.
         """
@@ -548,12 +557,14 @@ class CTCLanguageDecoder:
         ]
 
         # Normalize confidence score
         out["confidence"] = [
-            np.exp(hypothesis[0].score / hypothesis[0].timesteps[-1].item())
-            for hypothesis in hypotheses
+            np.exp(
+                hypothesis[0].score / ((self.language_model_weight + 1) * length.item())
+            )
+            for hypothesis, length in zip(hypotheses, batch_sizes)
         ]
         return out
 
-    def __call__(self, batch_features, batch_sizes):
+    def __call__(self, batch_features, batch_frames):
         """
         Decode a feature vector using n-gram language modelling.
 
         Args:
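
A sketch of the new confidence normalization (the score decomposition is an assumption, not stated in the diff): if a hypothesis score combines the emission log-probability with language_model_weight times the LM log-probability, dividing by (language_model_weight + 1) * length averages the combined score per frame before exponentiating, keeping the confidence in (0, 1]:

import numpy as np

def normalized_confidence(score: float, lm_weight: float, n_frames: int) -> float:
    # score is a (negative) combined log-probability; the result lies in (0, 1]
    return float(np.exp(score / ((lm_weight + 1) * n_frames)))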
@@ -565,21 +576,16 @@
         # Reshape from (batch_size, n_tokens, n_frames) to (batch_size, n_frames, n_tokens)
         batch_features = batch_features.permute((0, 2, 1))
 
         # Apply temperature scaling
-        batch_features = batch_features / self.temperature
-
-        # Apply log softmax
-        batch_features = torch.nn.functional.log_softmax(batch_features, dim=-1)
-        batch_features = self.add_ctc_frames(batch_features)
-        batch_sizes *= 2
+        batch_features = torch.nn.functional.log_softmax(
+            batch_features / self.temperature, dim=-1
+        )
+        batch_features, batch_frames = self.add_ctc_frames(batch_features, batch_frames)
 
         # No GPU support for torchaudio's ctc_decoder
-        device = torch.device("cpu")
-        batch_features = batch_features.to(device)
-        if isinstance(batch_sizes, list):
-            batch_sizes = torch.tensor(batch_sizes)
-            batch_sizes.to(device)
+        batch_features = batch_features.to(self.device)
+        batch_frames = batch_frames.to(self.device)
 
         # Decode
-        hypotheses = self.decoder(batch_features, batch_sizes)
-        return self.post_process(hypotheses)
+        hypotheses = self.decoder(batch_features, batch_frames)
+        return self.post_process(hypotheses, batch_frames)
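
Aside on the fused temperature scaling above: log_softmax(x / T) is unchanged for T = 1 and flattens the distribution for T > 1. A tiny self-contained check with made-up logits:

import torch

logits = torch.tensor([[2.0, 1.0, 0.1]])
for T in (1.0, 2.0):
    probs = torch.nn.functional.log_softmax(logits / T, dim=-1).exp()
    print(T, probs.tolist())  # higher T -> flatter distribution over tokens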
@@ -90,6 +90,7 @@ class DAN:
         self.encoder = encoder
         self.decoder = decoder
+        self.lm_decoder = None
 
         if use_language_model:
             self.lm_decoder = CTCLanguageDecoder(
@@ -97,9 +98,6 @@
                 lexicon_path=parameters["lm_decoder"]["lexicon_path"],
                 tokens_path=parameters["lm_decoder"]["tokens_path"],
                 language_model_weight=parameters["lm_decoder"]["language_model_weight"],
-                blank_token=parameters["lm_decoder"]["blank_token"],
-                unk_token=parameters["lm_decoder"]["unk_token"],
-                sil_token=parameters["lm_decoder"]["sil_token"],
             )
 
         self.mean, self.std = (
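
After this change the lm_decoder configuration only carries paths and a weight, since the special tokens are hard-wired through LM_MAPPING. A hypothetical parameters snippet (the language_model_path key and the file names are assumptions for illustration):

parameters = {
    "lm_decoder": {
        "language_model_path": "language_model.arpa",  # hypothetical path
        "lexicon_path": "lexicon.txt",                 # hypothetical path
        "tokens_path": "tokens.txt",                   # hypothetical path
        "language_model_weight": 1.0,
        # blank_token / unk_token / sil_token were removed by this commit:
        # dan.utils.LM_MAPPING provides them now.
    }
}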
@@ -178,6 +176,7 @@
             (batch_size,), dtype=torch.int, device=self.device
         )
 
+        # end token index will be used for ctc
         tot_pred = torch.zeros(
             (batch_size, len(self.charset) + 1, self.max_chars),
             dtype=torch.float,
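
The + 1 on the token axis of tot_pred reserves one row beyond the charset; per the new comment, the end-token index is reused as the CTC blank. A quick shape check with made-up sizes:

import torch

batch_size, charset_size, max_chars = 1, 100, 200
tot_pred = torch.zeros((batch_size, charset_size + 1, max_chars), dtype=torch.float)
assert tot_pred.shape == (1, 101, 200)  # last row along dim 1 acts as the blank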
@@ -268,7 +267,7 @@
         out["text"] = predicted_text
         if use_language_model:
-            out["language_model"] = self.lm_decoder(tot_pred, predicted_tokens_len)
+            out["language_model"] = self.lm_decoder(tot_pred, prediction_len)
         if confidences:
             out["confidences"] = confidence_scores
         if attentions:
@@ -474,7 +473,14 @@ def run(
     cuda_device = f":{gpu_device}" if gpu_device is not None else ""
     device = f"cuda{cuda_device}" if torch.cuda.is_available() else "cpu"
     dan_model = DAN(device, temperature)
-    dan_model.load(model, parameters, charset, mode="eval")
+    dan_model.load(
+        model, parameters, charset, mode="eval", use_language_model=use_language_model
+    )
+    # Language-model decoding currently runs with a batch size of 1
+    batch_size = 1 if use_language_model else batch_size
 
     images = image_dir.rglob(f"*{image_extension}") if not image else [image]
     for image_batch in list_to_batches(images, n=batch_size):