Commit afde811f authored by Solene Tarride

Use named tuple for special tokens

parent a278056d
@@ -36,13 +36,13 @@ from dan.datasets.extract.utils import (
     normalize_linebreaks,
     normalize_spaces,
 )
-from dan.utils import LM_MAPPING, EntityType, parse_tokens
-from line_image_extractor.extractor import extract
 from line_image_extractor.image_utils import (
     BoundingBox,
     Extraction,
     polygon_to_bbox,
 )
+from dan.utils import EntityType, LMTokenMapping, parse_tokens
+from line_image_extractor.extractor import extract

 IMAGES_DIR = "images"  # Subpath to the images directory.
 LANGUAGE_DIR = "language_model"  # Subpath to the language model directory.
@@ -281,7 +281,7 @@ class ArkindexExtractor:
         """
         return " ".join(
             [
-                LM_MAPPING[token] if token in LM_MAPPING else token
+                self.mapping.encode[token] if token in self.mapping.encode else token
                 for token in list(text.strip())
             ]
         )
@@ -370,14 +370,14 @@ class ArkindexExtractor:
         """
         for token in sorted(list(self.charset)):
             assert (
-                token not in LM_MAPPING.values()
+                token not in self.mapping.encode.values()
             ), f"Special token {token} is reserved for language modeling."
             self.language_tokens.append(
-                LM_MAPPING[token]
-            ) if token in LM_MAPPING else self.language_tokens.append(token)
+                self.mapping.encode[token]
+            ) if token in self.mapping.encode else self.language_tokens.append(token)

         # Add the special blank token
-        self.language_tokens.append(LM_MAPPING["<ctc>"])
+        self.language_tokens.append(self.mapping.ctc.encoded)

         # Build lexicon
         assert all(
...
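For context, a standalone sketch (not part of the commit) of what the updated encoding in ArkindexExtractor produces, with LMTokenMapping as defined in the dan/utils.py hunk further down; the input string is invented:

from dan.utils import LMTokenMapping

mapping = LMTokenMapping()
text = "le 5\nmai"
# Character-level tokens, space-separated, with special characters mapped
# to their encoded symbols -- mirroring the list comprehension above.
lm_text = " ".join(
    mapping.encode[c] if c in mapping.encode else c for c in list(text.strip())
)
print(lm_text)  # l e ⎵ 5 ↵ m a i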
@@ -7,7 +7,7 @@ from torch.nn import Conv1d, Dropout, Embedding, LayerNorm, Linear, Module, Modu
 from torch.nn.init import xavier_uniform_
 from torchaudio.models.decoder import ctc_decoder

-from dan.utils import LM_MAPPING, read_txt
+from dan.utils import LMTokenMapping, read_txt


 class PositionalEncoding1D(Module):
@@ -489,23 +489,21 @@ class CTCLanguageDecoder:
         language_model_weight: float = 1.0,
         temperature: float = 1.0,
     ):
-        self.space_token = LM_MAPPING[" "]
-        self.unknown_token = LM_MAPPING["<unk>"]
-        self.blank_token = LM_MAPPING["<ctc>"]
+        self.mapping = LMTokenMapping()
         self.language_model_weight = language_model_weight
         self.temperature = temperature
         self.tokens_to_index = {
             token: i for i, token in enumerate(read_txt(tokens_path).split("\n"))
         }
-        self.blank_token_id = self.tokens_to_index[self.blank_token]
+        self.blank_token_id = self.tokens_to_index[self.mapping.ctc.encoded]
         self.decoder = ctc_decoder(
             lm=language_model_path,
             lexicon=lexicon_path,
             tokens=tokens_path,
             lm_weight=self.language_model_weight,
-            blank_token=self.blank_token,
-            unk_word=self.unknown_token,
-            sil_token=self.space_token,
+            blank_token=self.mapping.ctc.encoded,
+            unk_word=self.mapping.unknown.encoded,
+            sil_token=self.mapping.space.encoded,
             nbest=1,
         )
         # No GPU support
@@ -550,7 +548,14 @@ class CTCLanguageDecoder:
         out = {}
         # Replace <space> by an actual space and format string
         out["text"] = [
-            "".join(hypothesis[0].words).replace(self.space_token, " ")
+            "".join(
+                [
+                    self.mapping.display[token]
+                    if token in self.mapping.display
+                    else token
+                    for token in hypothesis[0].words
+                ]
+            )
             for hypothesis in hypotheses
         ]
         # Normalize confidence score
...
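Again as a sketch, not from the commit: the post-processing now maps every decoded word through mapping.display, so linebreaks are restored as well as spaces, where the old .replace() call handled spaces only. The words list below is a stand-in for torchaudio's hypothesis[0].words:

from dan.utils import LMTokenMapping

mapping = LMTokenMapping()
words = ["le", "⎵", "5", "↵", "mai"]  # stand-in for hypothesis[0].words
# Encoded special symbols are replaced by their display form; plain
# tokens pass through unchanged.
decoded = "".join(
    mapping.display[w] if w in mapping.display else w for w in words
)
print(repr(decoded))  # 'le 5\nmai'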
@@ -16,12 +16,24 @@ class MLflowNotInstalled(Exception):
     """


-LM_MAPPING = {
-    " ": "⎵",
-    "\n": "↵",
-    "<ctc>": "◌",
-    "<unk>": "⁇",
-}
+class Token(NamedTuple):
+    encoded: str
+    display: str
+
+
+class LMTokenMapping(NamedTuple):
+    space: Token = Token("⎵", " ")
+    linebreak: Token = Token("↵", "\n")
+    ctc: Token = Token("◌", "<ctc>")
+    unknown: Token = Token("⁇", "<unk>")
+
+    @property
+    def display(self):
+        return {a.encoded: a.display for a in self}
+
+    @property
+    def encode(self):
+        return {a.display: a.encoded for a in self}


 class EntityType(NamedTuple):
...
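To summarize the new API, a minimal usage sketch (assuming the Token characters above, and dan.utils as it stands after this commit):

from dan.utils import LMTokenMapping

mapping = LMTokenMapping()
# Named attribute access replaces string-keyed LM_MAPPING lookups.
assert mapping.ctc.encoded == "◌"
# encode: display form -> encoded symbol (used when building the LM corpus).
assert mapping.encode[" "] == "⎵"
# display: encoded symbol -> display form (used when decoding hypotheses).
assert mapping.display["↵"] == "\n"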