From 6942bd0d5df20da39d9e3c6e1ba25bc324c55256 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sol=C3=A8ne=20Tarride?= <starride@teklia.com> Date: Wed, 20 Sep 2023 11:54:50 +0200 Subject: [PATCH] Use named tuple for special tokens --- dan/datasets/extract/arkindex.py | 14 +++++++------- dan/ocr/decoder.py | 23 ++++++++++++++--------- dan/utils.py | 24 ++++++++++++++++++------ 3 files changed, 39 insertions(+), 22 deletions(-) diff --git a/dan/datasets/extract/arkindex.py b/dan/datasets/extract/arkindex.py index 7e187156..e3c76034 100644 --- a/dan/datasets/extract/arkindex.py +++ b/dan/datasets/extract/arkindex.py @@ -36,13 +36,13 @@ from dan.datasets.extract.utils import ( normalize_linebreaks, normalize_spaces, ) -from dan.utils import LM_MAPPING, EntityType, parse_tokens -from line_image_extractor.extractor import extract from line_image_extractor.image_utils import ( BoundingBox, Extraction, polygon_to_bbox, ) +from dan.utils import EntityType, LMTokenMapping, parse_tokens +from line_image_extractor.extractor import extract IMAGES_DIR = "images" # Subpath to the images directory. LANGUAGE_DIR = "language_model" # Subpath to the language model directory. @@ -281,7 +281,7 @@ class ArkindexExtractor: """ return " ".join( [ - LM_MAPPING[token] if token in LM_MAPPING else token + self.mapping.encode[token] if token in self.mapping.encode else token for token in list(text.strip()) ] ) @@ -370,14 +370,14 @@ class ArkindexExtractor: """ for token in sorted(list(self.charset)): assert ( - token not in LM_MAPPING.values() + token not in self.mapping.encode.values() ), f"Special token {token} is reserved for language modeling." 
self.language_tokens.append( - LM_MAPPING[token] - ) if token in LM_MAPPING else self.language_tokens.append(token) + self.mapping.encode[token] + ) if token in self.mapping.encode else self.language_tokens.append(token) # Add the special blank token - self.language_tokens.append(LM_MAPPING["<ctc>"]) + self.language_tokens.append(self.mapping.ctc.encoded) # Build lexicon assert all( diff --git a/dan/ocr/decoder.py b/dan/ocr/decoder.py index 6f635ee0..568a7cc1 100644 --- a/dan/ocr/decoder.py +++ b/dan/ocr/decoder.py @@ -7,7 +7,7 @@ from torch.nn import Conv1d, Dropout, Embedding, LayerNorm, Linear, Module, Modu from torch.nn.init import xavier_uniform_ from torchaudio.models.decoder import ctc_decoder -from dan.utils import LM_MAPPING, read_txt +from dan.utils import LMTokenMapping, read_txt class PositionalEncoding1D(Module): @@ -489,23 +489,21 @@ class CTCLanguageDecoder: language_model_weight: float = 1.0, temperature: float = 1.0, ): - self.space_token = LM_MAPPING[" "] - self.unknown_token = LM_MAPPING["<unk>"] - self.blank_token = LM_MAPPING["<ctc>"] + self.mapping = LMTokenMapping() self.language_model_weight = language_model_weight self.temperature = temperature self.tokens_to_index = { token: i for i, token in enumerate(read_txt(tokens_path).split("\n")) } - self.blank_token_id = self.tokens_to_index[self.blank_token] + self.blank_token_id = self.tokens_to_index[self.mapping.ctc.encoded] self.decoder = ctc_decoder( lm=language_model_path, lexicon=lexicon_path, tokens=tokens_path, lm_weight=self.language_model_weight, - blank_token=self.blank_token, - unk_word=self.unknown_token, - sil_token=self.space_token, + blank_token=self.mapping.ctc.encoded, + unk_word=self.mapping.unknown.encoded, + sil_token=self.mapping.space.encoded, nbest=1, ) # No GPU support @@ -550,7 +548,14 @@ class CTCLanguageDecoder: out = {} # Replace <space> by an actual space and format string out["text"] = [ - "".join(hypothesis[0].words).replace(self.space_token, " ") + "".join( + [ + 
self.mapping.display[token] + if token in self.mapping.display + else token + for token in hypothesis[0].words + ] + ) for hypothesis in hypotheses ] # Normalize confidence score diff --git a/dan/utils.py b/dan/utils.py index 97135eba..93e62d5b 100644 --- a/dan/utils.py +++ b/dan/utils.py @@ -16,12 +16,24 @@ class MLflowNotInstalled(Exception): """ -LM_MAPPING = { - " ": "⎵", - "\n": "↵", - "<ctc>": "◌", - "<unk>": "⇍", -} +class Token(NamedTuple): + encoded: str + display: str + + +class LMTokenMapping(NamedTuple): + space: Token = Token("⎵", " ") + linebreak: Token = Token("↵", "\n") + ctc: Token = Token("◌", "<ctc>") + unknown: Token = Token("⇍", "<unk>") + + @property + def display(self): + return {a.encoded: a.display for a in self} + + @property + def encode(self): + return {a.display: a.encoded for a in self} class EntityType(NamedTuple): -- GitLab