
Add Language Model Decoder

Merged: Solene Tarride requested to merge lm-decoder into main
All threads resolved!
2 files changed: +47 −41
@@ -91,6 +91,7 @@ class DAN:
        self.encoder = encoder
        self.decoder = decoder
        self.lm_decoder = None
        if use_language_model:
            self.lm_decoder = CTCLanguageDecoder(
@@ -98,9 +99,6 @@ class DAN:
                lexicon_path=parameters["lm_decoder"]["lexicon_path"],
                tokens_path=parameters["lm_decoder"]["tokens_path"],
                language_model_weight=parameters["lm_decoder"]["language_model_weight"],
                blank_token=parameters["lm_decoder"]["blank_token"],
                unk_token=parameters["lm_decoder"]["unk_token"],
                sil_token=parameters["lm_decoder"]["sil_token"],
            )
        self.mean, self.std = (
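
Note: for readers outside this MR, a minimal sketch of what a decoder like CTCLanguageDecoder could look like, assuming it wraps torchaudio's KenLM-backed CTC beam-search decoder. The class itself is not part of this diff, so the name, parameters, and internals below are illustrative only.

# Illustrative sketch only: the real CTCLanguageDecoder is not shown in this MR.
# Assumes torchaudio with the CTC beam-search decoder (ctc_decoder) available.
import torch
from torchaudio.models.decoder import ctc_decoder


class CTCLanguageDecoderSketch:
    """Rescore CTC logits with an n-gram language model (hypothetical example)."""

    def __init__(self, language_model_path, lexicon_path, tokens_path, language_model_weight):
        self.decoder = ctc_decoder(
            lm=language_model_path,
            lexicon=lexicon_path,
            tokens=tokens_path,
            lm_weight=language_model_weight,
        )

    def __call__(self, logits, lengths):
        # Assumes logits come in as (batch, num_tokens, time), like tot_pred below;
        # torchaudio expects (batch, time, num_tokens) log-probabilities on CPU.
        log_probs = torch.nn.functional.log_softmax(logits.permute(0, 2, 1), dim=-1)
        hypotheses = self.decoder(log_probs.cpu(), lengths.cpu())
        return [" ".join(hyp[0].words) for hyp in hypotheses]
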
@@ -178,6 +176,7 @@ class DAN:
            (batch_size,), dtype=torch.int, device=self.device
        )
        # end token index will be used for ctc
        tot_pred = torch.zeros(
            (batch_size, len(self.charset) + 1, self.max_chars),
            dtype=torch.float,
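
Note: the extra channel in tot_pred (len(self.charset) + 1) leaves room for the end token, which the new comment says is reused as the CTC end/blank index. A toy illustration of that buffer, with a made-up charset and sizes:

# Toy illustration of the tot_pred buffer shape; charset and sizes are made up.
import torch

charset = ["a", "b", "c"]                       # hypothetical charset
batch_size, max_chars = 2, 5
tot_pred = torch.zeros(
    (batch_size, len(charset) + 1, max_chars),  # +1 channel for the end/CTC token
    dtype=torch.float,
)

# At decoding step t, per-character scores would be written into tot_pred[:, :, t]
tot_pred[:, :, 0] = torch.randn(batch_size, len(charset) + 1)
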
@@ -268,7 +267,7 @@ class DAN:
out["text"] = predicted_text
if use_language_model:
out["language_model"] = self.lm_decoder(tot_pred, predicted_tokens_len)
out["language_model"] = self.lm_decoder(tot_pred, prediction_len)
if confidences:
out["confidences"] = confidence_scores
if attentions:
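
Note: the one-line fix above passes prediction_len, the per-sample number of decoded steps, to the LM decoder instead of predicted_tokens_len. The call pattern, reusing the hypothetical wrapper sketched earlier (sizes and paths are made up):

# Hypothetical usage of the wrapper sketched above; shapes mirror the diff.
import torch

num_tokens = 101                               # len(charset) + 1, made-up size
tot_pred = torch.zeros(2, num_tokens, 50)      # (batch, tokens, max_chars)
prediction_len = torch.tensor([50, 32])        # decoded length of each sample

lm_decoder = CTCLanguageDecoderSketch(
    language_model_path="lm.arpa",             # hypothetical paths
    lexicon_path="lexicon.txt",
    tokens_path="tokens.txt",
    language_model_weight=1.0,
)
texts = lm_decoder(tot_pred, prediction_len)   # one rescored string per sample
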
@@ -509,6 +508,7 @@ def run(
    dan_model.load(
        model, parameters, charset, mode="eval", use_language_model=use_language_model
    )
    batch_size = 1 if use_language_model else batch_size
    images = image_dir.rglob(f"*{image_extension}") if not image else [image]
    for image_batch in list_to_batches(images, n=batch_size):
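
Note: batch_size is forced to 1 whenever the language model is enabled, so each call to the LM decoder sees a single image; the diff does not state the reason, presumably it avoids mixing padded samples of different lengths. A self-contained illustration of that batching pattern, with list_to_batches re-implemented here as a stand-in rather than the project's helper:

# Stand-in batching helper to illustrate the loop above; not the project's code.
from itertools import islice

def list_to_batches(iterable, n):
    """Yield successive lists of at most n items."""
    iterator = iter(iterable)
    while batch := list(islice(iterator, n)):
        yield batch

use_language_model = True
batch_size = 4
batch_size = 1 if use_language_model else batch_size
for image_batch in list_to_batches(["a.png", "b.png", "c.png"], n=batch_size):
    print(image_batch)   # with the LM enabled, each batch holds one image
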