Commit f0448bf8 authored by Solene Tarride

Use named tuple for special tokens

parent 08603b35
@@ -36,13 +36,13 @@ from dan.datasets.extract.utils import (
     normalize_linebreaks,
     normalize_spaces,
 )
-from dan.utils import LM_MAPPING, EntityType, parse_tokens
+from dan.utils import EntityType, LMTokenMapping, parse_tokens
 from line_image_extractor.extractor import extract
 from line_image_extractor.image_utils import (
     BoundingBox,
     Extraction,
     polygon_to_bbox,
 )

 IMAGES_DIR = "images"  # Subpath to the images directory.
@@ -88,6 +88,7 @@ class ArkindexExtractor:
         self.max_width = max_width
         self.max_height = max_height
         self.image_extension = image_extension
+        self.mapping = LMTokenMapping()
         self.keep_spaces = keep_spaces
@@ -275,7 +276,7 @@ class ArkindexExtractor:
         """
         return " ".join(
             [
-                LM_MAPPING[token] if token in LM_MAPPING else token
+                self.mapping.encode[token] if token in self.mapping.encode else token
                 for token in list(text.strip())
             ]
         )
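For context, a minimal sketch of what this per-character encoding produces, assuming the LMTokenMapping introduced later in this commit; the standalone function is illustrative, not the committed method:

from dan.utils import LMTokenMapping

mapping = LMTokenMapping()

def format_text_language_model(text: str) -> str:
    # Encode each character; spaces and linebreaks become dedicated
    # symbols so the space-separated LM corpus stays unambiguous.
    return " ".join(
        mapping.encode[token] if token in mapping.encode else token
        for token in list(text.strip())
    )

print(format_text_language_model("ab c"))  # a b ⎵ c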
@@ -360,14 +361,14 @@ class ArkindexExtractor:
         """
         for token in sorted(list(self.charset)):
             assert (
-                token not in LM_MAPPING.values()
+                token not in self.mapping.encode.values()
             ), f"Special token {token} is reserved for language modeling."
             self.language_tokens.append(
-                LM_MAPPING[token]
-            ) if token in LM_MAPPING else self.language_tokens.append(token)
+                self.mapping.encode[token]
+            ) if token in self.mapping.encode else self.language_tokens.append(token)
         # Add the special blank token
-        self.language_tokens.append(LM_MAPPING["<ctc>"])
+        self.language_tokens.append(self.mapping.ctc.encoded)
         # Build lexicon
         assert all(
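A sketch of the token-list construction above with a toy charset; the inlined charset and the use of dict.get instead of the conditional-expression append are illustrative simplifications:

from dan.utils import LMTokenMapping

mapping = LMTokenMapping()
charset = {" ", "a", "b"}
language_tokens = []
for token in sorted(list(charset)):
    # Plain characters pass through; characters with a reserved
    # meaning (space, linebreak) are stored in encoded form.
    assert (
        token not in mapping.encode.values()
    ), f"Special token {token} is reserved for language modeling."
    language_tokens.append(mapping.encode.get(token, token))
# Add the special blank token expected by the CTC decoder
language_tokens.append(mapping.ctc.encoded)
print(language_tokens)  # ['⎵', 'a', 'b', '◌']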
@@ -7,7 +7,7 @@ from torch.nn import Conv1d, Dropout, Embedding, LayerNorm, Linear, Module, Modu
 from torch.nn.init import xavier_uniform_
 from torchaudio.models.decoder import ctc_decoder

-from dan.utils import LM_MAPPING, read_txt
+from dan.utils import LMTokenMapping, read_txt


 class PositionalEncoding1D(Module):
@@ -489,23 +489,21 @@ class CTCLanguageDecoder:
         language_model_weight: float = 1.0,
         temperature: float = 1.0,
     ):
-        self.space_token = LM_MAPPING[" "]
-        self.unknown_token = LM_MAPPING["<unk>"]
-        self.blank_token = LM_MAPPING["<ctc>"]
+        self.mapping = LMTokenMapping()
         self.language_model_weight = language_model_weight
         self.temperature = temperature
         self.tokens_to_index = {
             token: i for i, token in enumerate(read_txt(tokens_path).split("\n"))
         }
-        self.blank_token_id = self.tokens_to_index[self.blank_token]
+        self.blank_token_id = self.tokens_to_index[self.mapping.ctc.encoded]
         self.decoder = ctc_decoder(
             lm=language_model_path,
             lexicon=lexicon_path,
             tokens=tokens_path,
             lm_weight=self.language_model_weight,
-            blank_token=self.blank_token,
-            unk_word=self.unknown_token,
-            sil_token=self.space_token,
+            blank_token=self.mapping.ctc.encoded,
+            unk_word=self.mapping.unknown.encoded,
+            sil_token=self.mapping.space.encoded,
             nbest=1,
         )
         # No GPU support
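A hedged sketch of how this constructor wires the mapping into torchaudio's ctc_decoder; the file paths are placeholders, not values from this commit:

from torchaudio.models.decoder import ctc_decoder

from dan.utils import LMTokenMapping

mapping = LMTokenMapping()
decoder = ctc_decoder(
    lm="language_model/model.arpa",        # placeholder path
    lexicon="language_model/lexicon.txt",  # placeholder path
    tokens="language_model/tokens.txt",    # placeholder path
    lm_weight=1.0,
    blank_token=mapping.ctc.encoded,       # "◌"
    unk_word=mapping.unknown.encoded,      # "⁇"
    sil_token=mapping.space.encoded,       # "⎵"
    nbest=1,
)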
@@ -550,7 +548,14 @@ class CTCLanguageDecoder:
         out = {}
         # Replace <space> by an actual space and format string
         out["text"] = [
-            "".join(hypothesis[0].words).replace(self.space_token, " ")
+            "".join(
+                [
+                    self.mapping.display[token]
+                    if token in self.mapping.display
+                    else token
+                    for token in hypothesis[0].words
+                ]
+            )
             for hypothesis in hypotheses
         ]
         # Normalize confidence score
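The display mapping inverts the encoding at output time. A small sketch, with a made-up decoder hypothesis:

from dan.utils import LMTokenMapping

mapping = LMTokenMapping()
words = ["h", "i", "⎵", "y", "o", "u"]  # illustrative decoder output
text = "".join(
    mapping.display[token] if token in mapping.display else token
    for token in words
)
print(text)  # hi you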
@@ -16,12 +16,24 @@ class MLflowNotInstalled(Exception):
     """


-LM_MAPPING = {
-    " ": "⎵",
-    "\n": "↵",
-    "<ctc>": "◌",
-    "<unk>": "⁇",
-}
+class Token(NamedTuple):
+    encoded: str
+    display: str
+
+
+class LMTokenMapping(NamedTuple):
+    space: Token = Token("⎵", " ")
+    linebreak: Token = Token("↵", "\n")
+    ctc: Token = Token("◌", "<ctc>")
+    unknown: Token = Token("⁇", "<unk>")
+
+    @property
+    def display(self):
+        return {a.encoded: a.display for a in self}
+
+    @property
+    def encode(self):
+        return {a.display: a.encoded for a in self}


 class EntityType(NamedTuple):
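A quick round-trip check of the new named tuple (a sketch, not part of the commit): encode maps display forms to encoded symbols, and display inverts it:

from dan.utils import LMTokenMapping

mapping = LMTokenMapping()
assert mapping.encode[" "] == "⎵"
assert mapping.display["⎵"] == " "
assert mapping.display[mapping.encode["\n"]] == "\n"
assert mapping.ctc.encoded == "◌" and mapping.ctc.display == "<ctc>"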