diff --git a/dan/datasets/extract/arkindex.py b/dan/datasets/extract/arkindex.py index 7e187156e465ee7d62864f1a1d4969c2726f89d0..e3c7603430811231b99e6ac7a84e207e02876d37 100644 --- a/dan/datasets/extract/arkindex.py +++ b/dan/datasets/extract/arkindex.py @@ -36,13 +36,13 @@ from dan.datasets.extract.utils import ( normalize_linebreaks, normalize_spaces, ) -from dan.utils import LM_MAPPING, EntityType, parse_tokens -from line_image_extractor.extractor import extract from line_image_extractor.image_utils import ( BoundingBox, Extraction, polygon_to_bbox, ) +from dan.utils import EntityType, LMTokenMapping, parse_tokens +from line_image_extractor.extractor import extract IMAGES_DIR = "images" # Subpath to the images directory. LANGUAGE_DIR = "language_model" # Subpath to the language model directory. @@ -281,7 +281,7 @@ class ArkindexExtractor: """ return " ".join( [ - LM_MAPPING[token] if token in LM_MAPPING else token + self.mapping.encode[token] if token in self.mapping.encode else token for token in list(text.strip()) ] ) @@ -370,14 +370,14 @@ class ArkindexExtractor: """ for token in sorted(list(self.charset)): assert ( - token not in LM_MAPPING.values() + token not in self.mapping.encode.values() ), f"Special token {token} is reserved for language modeling." 
self.language_tokens.append( - LM_MAPPING[token] - ) if token in LM_MAPPING else self.language_tokens.append(token) + self.mapping.encode[token] + ) if token in self.mapping.encode else self.language_tokens.append(token) # Add the special blank token - self.language_tokens.append(LM_MAPPING["<ctc>"]) + self.language_tokens.append(self.mapping.ctc.encoded) # Build lexicon assert all( diff --git a/dan/ocr/decoder.py b/dan/ocr/decoder.py index 6f635ee0a216f4e5788796a959090be9000d1049..568a7cc1a867e2a2ba56ff398e135646d848a5af 100644 --- a/dan/ocr/decoder.py +++ b/dan/ocr/decoder.py @@ -7,7 +7,7 @@ from torch.nn import Conv1d, Dropout, Embedding, LayerNorm, Linear, Module, Modu from torch.nn.init import xavier_uniform_ from torchaudio.models.decoder import ctc_decoder -from dan.utils import LM_MAPPING, read_txt +from dan.utils import LMTokenMapping, read_txt class PositionalEncoding1D(Module): @@ -489,23 +489,21 @@ class CTCLanguageDecoder: language_model_weight: float = 1.0, temperature: float = 1.0, ): - self.space_token = LM_MAPPING[" "] - self.unknown_token = LM_MAPPING["<unk>"] - self.blank_token = LM_MAPPING["<ctc>"] + self.mapping = LMTokenMapping() self.language_model_weight = language_model_weight self.temperature = temperature self.tokens_to_index = { token: i for i, token in enumerate(read_txt(tokens_path).split("\n")) } - self.blank_token_id = self.tokens_to_index[self.blank_token] + self.blank_token_id = self.tokens_to_index[self.mapping.ctc.encoded] self.decoder = ctc_decoder( lm=language_model_path, lexicon=lexicon_path, tokens=tokens_path, lm_weight=self.language_model_weight, - blank_token=self.blank_token, - unk_word=self.unknown_token, - sil_token=self.space_token, + blank_token=self.mapping.ctc.encoded, + unk_word=self.mapping.unknown.encoded, + sil_token=self.mapping.space.encoded, nbest=1, ) # No GPU support @@ -550,7 +548,14 @@ class CTCLanguageDecoder: out = {} # Replace <space> by an actual space and format string out["text"] = [ - 
"".join(hypothesis[0].words).replace(self.space_token, " ") + "".join( + [ + self.mapping.display[token] + if token in self.mapping.display + else token + for token in hypothesis[0].words + ] + ) for hypothesis in hypotheses ] # Normalize confidence score diff --git a/dan/utils.py b/dan/utils.py index 97135ebaf1b7370d3a9fdac1e53c38660886dcde..93e62d5bac348098f7c3451af34edc2a5d8f3fab 100644 --- a/dan/utils.py +++ b/dan/utils.py @@ -16,12 +16,24 @@ class MLflowNotInstalled(Exception): """ -LM_MAPPING = { - " ": "⎵", - "\n": "↵", - "<ctc>": "◌", - "<unk>": "⁇", -} +class Token(NamedTuple): + encoded: str + display: str + + +class LMTokenMapping(NamedTuple): + space: Token = Token("⎵", " ") + linebreak: Token = Token("↵", "\n") + ctc: Token = Token("◌", "<ctc>") + unknown: Token = Token("⁇", "<unk>") + + @property + def display(self): + return {a.encoded: a.display for a in self} + + @property + def encode(self): + return {a.display: a.encoded for a in self} class EntityType(NamedTuple):