Skip to content
Snippets Groups Projects
Commit 6942bd0d authored by Solene Tarride's avatar Solene Tarride
Browse files

Use named tuple for special tokens

parent fdb5ab7f
No related branches found
No related tags found
1 merge request: !287 — Support subword and word language models
......@@ -36,13 +36,13 @@ from dan.datasets.extract.utils import (
normalize_linebreaks,
normalize_spaces,
)
from dan.utils import LM_MAPPING, EntityType, parse_tokens
from line_image_extractor.extractor import extract
from line_image_extractor.image_utils import (
BoundingBox,
Extraction,
polygon_to_bbox,
)
from dan.utils import EntityType, LMTokenMapping, parse_tokens
from line_image_extractor.extractor import extract
IMAGES_DIR = "images" # Subpath to the images directory.
LANGUAGE_DIR = "language_model" # Subpath to the language model directory.
......@@ -281,7 +281,7 @@ class ArkindexExtractor:
"""
return " ".join(
[
LM_MAPPING[token] if token in LM_MAPPING else token
self.mapping.encode[token] if token in self.mapping else token
for token in list(text.strip())
]
)
......@@ -370,14 +370,14 @@ class ArkindexExtractor:
"""
for token in sorted(list(self.charset)):
assert (
token not in LM_MAPPING.values()
token not in self.mapping.encode.values()
), f"Special token {token} is reserved for language modeling."
self.language_tokens.append(
LM_MAPPING[token]
) if token in LM_MAPPING else self.language_tokens.append(token)
self.mapping.encode[token]
) if token in self.mapping.encode else self.language_tokens.append(token)
# Add the special blank token
self.language_tokens.append(LM_MAPPING["<ctc>"])
self.language_tokens.append(self.mapping.ctc.encoded)
# Build lexicon
assert all(
......
......@@ -7,7 +7,7 @@ from torch.nn import Conv1d, Dropout, Embedding, LayerNorm, Linear, Module, Modu
from torch.nn.init import xavier_uniform_
from torchaudio.models.decoder import ctc_decoder
from dan.utils import LM_MAPPING, read_txt
from dan.utils import LMTokenMapping, read_txt
class PositionalEncoding1D(Module):
......@@ -489,23 +489,21 @@ class CTCLanguageDecoder:
language_model_weight: float = 1.0,
temperature: float = 1.0,
):
self.space_token = LM_MAPPING[" "]
self.unknown_token = LM_MAPPING["<unk>"]
self.blank_token = LM_MAPPING["<ctc>"]
self.mapping = LMTokenMapping()
self.language_model_weight = language_model_weight
self.temperature = temperature
self.tokens_to_index = {
token: i for i, token in enumerate(read_txt(tokens_path).split("\n"))
}
self.blank_token_id = self.tokens_to_index[self.blank_token]
self.blank_token_id = self.tokens_to_index[self.mapping.ctc.encoded]
self.decoder = ctc_decoder(
lm=language_model_path,
lexicon=lexicon_path,
tokens=tokens_path,
lm_weight=self.language_model_weight,
blank_token=self.blank_token,
unk_word=self.unknown_token,
sil_token=self.space_token,
blank_token=self.mapping.ctc.encoded,
unk_word=self.mapping.unknown.encoded,
sil_token=self.mapping.space.encoded,
nbest=1,
)
# No GPU support
......@@ -550,7 +548,14 @@ class CTCLanguageDecoder:
out = {}
# Replace <space> by an actual space and format string
out["text"] = [
"".join(hypothesis[0].words).replace(self.space_token, " ")
"".join(
[
self.mapping.display[token]
if token in self.mapping.display
else token
for token in hypothesis[0].words
]
)
for hypothesis in hypotheses
]
# Normalize confidence score
......
......@@ -16,12 +16,24 @@ class MLflowNotInstalled(Exception):
"""
# Pre-refactor mapping from corpus characters / special markers to the symbols
# written into the language-model lexicon and tokens files (superseded by
# LMTokenMapping in this commit).
# NOTE(review): every value here is the empty string — the replacement symbols
# were presumably distinct non-ASCII characters that were lost in transit.
# Confirm the intended one-symbol-per-entry values before relying on this.
LM_MAPPING = {
    " ": "",
    "\n": "",
    "<ctc>": "",
    "<unk>": "",
}
# A language-model token: the symbol written to the LM files (`encoded`)
# paired with the character it stands for in normal text (`display`).
Token = NamedTuple("Token", [("encoded", str), ("display", str)])
class LMTokenMapping(NamedTuple):
    """Special tokens used by the CTC language model.

    Each field is a Token(encoded, display) pair: `encoded` is the symbol
    written to the LM lexicon/tokens files, `display` is the character it
    represents in normal text.
    """

    # NOTE(review): every `encoded` default below is the empty string, so the
    # `display` and `encode` dicts each collapse to a single entry. These look
    # like distinct non-ASCII placeholder symbols lost in transit — confirm
    # the intended characters before relying on this mapping.
    space: Token = Token("", " ")
    linebreak: Token = Token("", "\n")
    ctc: Token = Token("", "<ctc>")
    unknown: Token = Token("", "<unk>")

    @property
    def display(self) -> dict[str, str]:
        # Map: encoded LM symbol -> displayed character (used to turn decoder
        # output back into readable text).
        return {a.encoded: a.display for a in self}

    @property
    def encode(self) -> dict[str, str]:
        # Map: displayed character -> encoded LM symbol (used when writing
        # training text for the language model).
        return {a.display: a.encoded for a in self}
class EntityType(NamedTuple):
......
Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.