Commit 2538568c authored by Solene Tarride

Implement CTCLanguageDecoder

parent 095667f4
Merge request !287: Support subword and word language models
# -*- coding: utf-8 -*-
from typing import Dict, List, Union
import numpy as np
import torch
from torch import relu, softmax
from torch.nn import Conv1d, Dropout, Embedding, LayerNorm, Linear, Module, ModuleList
from torch.nn.init import xavier_uniform_
from torchaudio.models.decoder import CTCHypothesis, ctc_decoder

from dan.utils import LMTokenMapping, read_txt
class PositionalEncoding1D(Module):
......@@ -470,13 +468,17 @@ class GlobalHTADecoder(Module):
class CTCLanguageDecoder:
"""
Initialize a CTC decoder with n-gram language modeling.
:param language_model_path: Path to a KenLM or ARPA language model.
:param lexicon_path: Path to a lexicon file containing the possible words and corresponding spellings.
Each line consists of a word and its space separated spelling. If `None`, uses lexicon-free decoding.
:param tokens_path: Path to a file containing valid tokens. If using a file, the expected
format is for tokens mapping to the same index to be on the same line.
:param language_model_weight: Weight of the language model.
:param temperature: Temperature for model calibreation.
Args:
language_model_path (str): path to a KenLM or ARPA language model
lexicon_path (str): path to a lexicon file containing the possible words and corresponding spellings.
Each line consists of a word and its space separated spelling. If `None`, uses lexicon-free
decoding.
tokens_path (str): path to a file containing valid tokens. If using a file, the expected
format is for tokens mapping to the same index to be on the same line
language_model_weight (float): weight of the language model.
blank_token (str): token representing the blank/ctc symbol
unk_token (str): token representing unknown characters
sil_token (str): token representing the space character
"""
......@@ -485,138 +487,95 @@ class CTCLanguageDecoder:
    def __init__(
        self,
        language_model_path: str,
        lexicon_path: str,
        tokens_path: str,
        language_model_weight: float = 1.0,
        temperature: float = 1.0,
    ):
        self.mapping = LMTokenMapping()
        self.language_model_weight = language_model_weight
        self.temperature = temperature
        self.tokens_to_index = {
            token: i for i, token in enumerate(read_txt(tokens_path).split("\n"))
        }
        self.index_to_token = {i: token for token, i in self.tokens_to_index.items()}
        self.blank_token_id = self.tokens_to_index[self.mapping.ctc.encoded]
        # Torchaudio's decoder
        # https://pytorch.org/audio/master/generated/torchaudio.models.decoder.ctc_decoder.html
        self.decoder = ctc_decoder(
            lm=language_model_path,
            lexicon=lexicon_path,
            tokens=tokens_path,
            lm_weight=self.language_model_weight,
            blank_token=self.mapping.ctc.encoded,
            sil_token=self.mapping.space.encoded,
            unk_word="",
            nbest=1,
        )
        # No GPU support for torchaudio's ctc_decoder
        self.device = torch.device("cpu")
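    # Illustrative sketch of the resources consumed above (file names and
    # contents are hypothetical, following the formats described in the
    # class docstring):
    #   tokens.txt : one token per line, tokens sharing an index on the same
    #                line, e.g. "a", "b", ..., "<space>", "<unk>", "<ctc>"
    #   lexicon.txt: a word followed by its space-separated spelling,
    #                e.g. "chat c h a t <space>"
    #   model.arpa : an n-gram language model in ARPA format (e.g. from KenLM)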
    def add_ctc_frames(
        self, batch_features: torch.FloatTensor, batch_frames: torch.LongTensor
    ) -> tuple[torch.FloatTensor, torch.LongTensor]:
        """
        Add CTC frames between characters to avoid duplicate removal.
        """
        high_prob = batch_features.max()
        low_prob = batch_features.min()
        batch_size, n_frames, n_tokens = batch_features.shape
        # Reset probabilities for the CTC token
        batch_features[:, :, -1] = (
            torch.ones(
                (batch_size, n_frames),
                dtype=torch.float32,
                device=batch_features.device,
            )
            * low_prob
        )
        # Create a frame with high probability CTC token
        ctc_probs = (
            torch.ones(
                (batch_size, 1, n_tokens),
                dtype=torch.float32,
                device=batch_features.device,
            )
            * low_prob
        )
        ctc_probs[:, :, self.blank_token_id] = high_prob
        # Insert the CTC frame between regular frames
        for fn in range(batch_frames.max() - 1):
            batch_features = torch.cat(
                [
                    batch_features[:, : 2 * fn + 1, :],
                    ctc_probs,
                    batch_features[:, 2 * fn + 1 :, :],
                ],
                dim=1,
            )
        # Update the number of frames
        batch_frames = 2 * batch_frames - 1
        return batch_features, batch_frames
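    # Why the interleaving above matters (illustrative micro example): CTC
    # decoding collapses repeated tokens, so two consecutive frames that both
    # predict "l" would decode to a single "l". With a high-probability <ctc>
    # frame inserted between them, the sequence "l <ctc> l" decodes to "ll",
    # preserving genuine duplicates such as the "ll" in "hello".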
    def post_process(
        self, hypotheses: List[CTCHypothesis], batch_sizes: torch.LongTensor
    ) -> Dict[str, List[Union[str, float]]]:
        """
        Post-process hypotheses to output JSON. Exports only the best hypothesis for each image.

        :param hypotheses: List of hypotheses returned by the decoder.
        :param batch_sizes: Prediction length of size batch_size.
        :return: A dictionary containing the hypotheses and their confidences.
        """
        out = {}
        # Replace <space> by an actual space and format the string
        # Export only the best hypothesis
        out["text"] = [
            "".join(
                [
                    self.mapping.display[self.index_to_token[token]]
                    if self.index_to_token[token] in self.mapping.display
                    else self.index_to_token[token]
                    for token in hypothesis[0].tokens.tolist()
                ]
            ).strip()
            for hypothesis in hypotheses
        ]
        # Normalize confidence score
        out["confidence"] = [
            np.around(
                np.exp(
                    hypothesis[0].score
                    / ((self.language_model_weight + 1) * length.item())
                ),
                2,
            )
            for hypothesis, length in zip(hypotheses, batch_sizes)
        ]
        return out
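    # Illustrative note on the normalization above: the decoder score is a
    # log probability accumulated over roughly (language_model_weight + 1)
    # weighted terms per frame (acoustic score plus the weighted LM score), so
    # dividing by (language_model_weight + 1) * n_frames before exponentiating
    # gives a rough per-frame geometric mean in [0, 1]. For example, with
    # weight 1.0, 10 frames and score -2.0: exp(-2.0 / (2 * 10)) ≈ 0.9.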
    def __call__(
        self, batch_features: torch.FloatTensor, batch_frames: torch.LongTensor
    ) -> Dict[str, List[Union[str, float]]]:
        """
        Decode a feature vector using n-gram language modelling.

        :param batch_features: Feature vector of size (batch_size, n_tokens, n_frames).
        :param batch_frames: Prediction length of size batch_size.
        :return: A dictionary containing the hypotheses and their confidences.
        """
        # Reshape from (batch_size, n_tokens, n_frames) to (batch_size, n_frames, n_tokens)
        batch_features = batch_features.permute((0, 2, 1))
        # Insert CTC frames to avoid getting rid of duplicates
        # Make sure that the CTC token has low probs for other frames
        batch_features, batch_frames = self.add_ctc_frames(batch_features, batch_frames)
        # Apply temperature scaling and log softmax
        batch_features = torch.nn.functional.log_softmax(
            batch_features / self.temperature, dim=-1
        )
        # No GPU support for torchaudio's ctc_decoder
        batch_features = batch_features.to(self.device)
        batch_frames = batch_frames.to(self.device)
        # Decode
        hypotheses = self.decoder(batch_features, batch_frames)
        return self.post_process(hypotheses, batch_frames)
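
A minimal usage sketch, assuming hypothetical resource files (model.arpa, lexicon.txt and tokens.txt follow the formats sketched above) and a ten-token charset:

import torch

lm_decoder = CTCLanguageDecoder(
    language_model_path="model.arpa",
    lexicon_path="lexicon.txt",
    tokens_path="tokens.txt",
    language_model_weight=1.0,
    temperature=1.0,
)
# Dummy logits for a batch of 2 images, 10 tokens and 50 frames, in the
# (batch_size, n_tokens, n_frames) layout expected by __call__; the second
# image only uses 35 frames.
batch_features = torch.randn((2, 10, 50))
batch_frames = torch.tensor([50, 35])
out = lm_decoder(batch_features, batch_frames)
print(out["text"], out["confidence"])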
......@@ -169,7 +169,7 @@ def add_predict_parser(subcommands) -> None:
    )
    parser.add_argument(
        "--use-language-model",
        help="Whether to use an explicit language model to rescore text hypotheses.",
        action="store_true",
        required=False,
    )
......
......@@ -77,6 +77,16 @@ class DAN:
        decoder = GlobalHTADecoder(parameters["decoder"]).to(self.device)
        decoder.load_state_dict(checkpoint["decoder_state_dict"], strict=True)
        self.lm_decoder = CTCLanguageDecoder(
            language_model_path=parameters["lm_decoder"]["language_model_path"],
            lexicon_path=parameters["lm_decoder"]["lexicon_path"],
            tokens_path=parameters["lm_decoder"]["tokens_path"],
            language_model_weight=parameters["lm_decoder"]["language_model_weight"],
            temperature=parameters["lm_decoder"]["temperature"],
        )
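        # Hypothetical shape of the "lm_decoder" block in `parameters`
        # (illustrative values matching the keyword arguments above):
        #   {
        #       "language_model_path": "lm/model.arpa",
        #       "lexicon_path": "lm/lexicon.txt",
        #       "tokens_path": "lm/tokens.txt",
        #       "language_model_weight": 1.0,
        #       "temperature": 1.0,
        #   }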
        logger.debug(f"Loaded model {model_path}")
        if mode == "train":
......@@ -179,7 +189,6 @@ class DAN:
            (batch_size,), dtype=torch.int, device=self.device
        )
        # end token index will be used for ctc
        tot_pred = torch.zeros(
            (batch_size, len(self.charset) + 1, self.max_chars),
            dtype=torch.float,
......@@ -270,7 +279,7 @@ class DAN:
out["text"] = predicted_text
if use_language_model:
out["language_model"] = self.lm_decoder(tot_pred, prediction_len)
out["language_model"] = self.lm_decoder(tot_pred, predicted_tokens_len)
if confidences:
out["confidences"] = confidence_scores
if attentions:
......@@ -466,7 +475,7 @@ def run(
    :param batch_size: Size of the batches for prediction.
    :param tokens: NER tokens used.
    :param start_token: Use a specific starting token at the beginning of the prediction. Useful when making predictions on different single pages.
    :param use_language_model: Whether to use an explicit language model to rescore text hypotheses.
    """
    # Create output directory if necessary
    if not output.exists():
......
......@@ -163,9 +163,7 @@ def read_json(json_path: str) -> Dict:
def read_txt(txt_path: str) -> str:
    """
    Read TXT file.

    :param txt_path: Path of the text file to read.
    :return: The content of the read file.
    """
    filename = Path(txt_path)
    assert filename.exists(), f"{txt_path} does not resolve."
......