From e9c3ac5b8f1fb6ac218aceca661e26ca2e7cffea Mon Sep 17 00:00:00 2001
From: Solène Tarride <starride@teklia.com>
Date: Fri, 15 Sep 2023 15:00:50 +0200
Subject: [PATCH] Generate LM files during data extraction

---
 dan/datasets/extract/extract.py | 27 +++++++++++++++++----------
 dan/datasets/extract/utils.py   | 16 +++++++---------
 dan/utils.py                    | 26 ++++++--------------------
 3 files changed, 30 insertions(+), 39 deletions(-)

diff --git a/dan/datasets/extract/extract.py b/dan/datasets/extract/extract.py
index a7fa8d32..c4d71632 100644
--- a/dan/datasets/extract/extract.py
+++ b/dan/datasets/extract/extract.py
@@ -34,9 +34,9 @@ from dan.datasets.extract.utils import (
     get_bbox,
     insert_token,
     normalize_linebreaks,
-    normalize_spaces,
+    normalize_spaces
 )
-from dan.utils import EntityType, LMTokenMapping, parse_tokens
+from dan.utils import LM_MAPPING, EntityType, parse_tokens
 from line_image_extractor.extractor import extract
 from line_image_extractor.image_utils import (
     BoundingBox,
@@ -279,7 +279,12 @@ class ArkindexExtractor:
         """
         Format text for the language model. Return the text tokenized at character-level.
         """
-        return " ".join(map(self.mapping.encode_token, list(text.strip())))
+        return " ".join(
+            [
+                LM_MAPPING[token] if token in LM_MAPPING else token
+                for token in list(text.strip())
+            ]
+        )
 
     def process_element(
         self,
@@ -318,6 +323,8 @@ class ArkindexExtractor:
 
         self.data[split][str(image_path)] = text
         self.charset = self.charset.union(set(text))
+        if split == "train":
+            self.language_corpus.append(self.format_text_language_model(text))
 
         # Language model should be built using only text from the training set
         if split == "train":
@@ -363,14 +370,14 @@
         """
         for token in sorted(list(self.charset)):
             assert (
-                token not in self.mapping.encode.values()
+                token not in LM_MAPPING.values()
             ), f"Special token {token} is reserved for language modeling."
             self.language_tokens.append(
-                self.mapping.encode[token]
-            ) if token in self.mapping.encode else self.language_tokens.append(token)
+                LM_MAPPING[token]
+            ) if token in LM_MAPPING else self.language_tokens.append(token)
 
         # Add the special blank token
-        self.language_tokens.append(self.mapping.ctc.encoded)
+        self.language_tokens.append(LM_MAPPING["<ctc>"])
 
         # Build lexicon
         assert all(
@@ -386,13 +393,13 @@
                 indent=4,
             )
         )
-        (self.output / "language_model" / "corpus.txt").write_text(
+        (self.output / "language_corpus.txt").write_text(
             "\n".join(self.language_corpus)
         )
-        (self.output / "language_model" / "tokens.txt").write_text(
+        (self.output / "language_tokens.txt").write_text(
             "\n".join(self.language_tokens)
         )
-        (self.output / "language_model" / "lexicon.txt").write_text(
+        (self.output / "language_lexicon.txt").write_text(
             "\n".join(self.language_lexicon)
         )
         (self.output / "charset.pkl").write_bytes(
diff --git a/dan/datasets/extract/utils.py b/dan/datasets/extract/utils.py
index 79e5ef6f..2863410d 100644
--- a/dan/datasets/extract/utils.py
+++ b/dan/datasets/extract/utils.py
@@ -21,8 +21,8 @@ logger = logging.getLogger(__name__)
 DOWNLOAD_TIMEOUT = (30, 60)
 
 # replace \t with regular space and consecutive spaces
-TRIM_SPACE_REGEX = re.compile(r"[\t ]+")
-TRIM_RETURN_REGEX = re.compile(r"[\r\n]+")
+TRIM_SPACE_REGEX = re.compile(r"[\t| ]+")
+TRIM_RETURN_REGEX = re.compile(r"[\r|\n]+")
 
 
 def _retry_log(retry_state, *args, **kwargs):
@@ -89,21 +89,18 @@ def insert_token(text: str, entity_type: EntityType, offset: int, length: int) -> str:
         + (entity_type.end if entity_type else "")
     )
 
-
 def normalize_linebreaks(text: str) -> str:
     """
-    Remove begin/ending linebreaks.
-    Replace \r with regular linebreak and consecutive linebreaks.
-    :param text: Text to normalize.
+    Remove begin/ending linebreaks
+    Replace \r with regular linebreak and consecutive linebreaks
     """
     return TRIM_RETURN_REGEX.sub("\n", text.strip())
 
 
 def normalize_spaces(text: str) -> str:
     """
-    Remove begin/ending spaces.
-    Replace \t with regular space and consecutive spaces.
-    :param text: Text to normalize.
+    Remove begin/ending spaces
+    Replace \t with regular space and consecutive spaces
     """
     return TRIM_SPACE_REGEX.sub(" ", text.strip())
 
@@ -117,3 +114,4 @@ def get_bbox(polygon: List[List[int]]) -> str:
     x, y = min(all_x), min(all_y)
     width, height = max(all_x) - x, max(all_y) - y
     return ",".join(list(map(str, [int(x), int(y), int(width), int(height)])))
+
diff --git a/dan/utils.py b/dan/utils.py
index 69e7d82a..97135eba 100644
--- a/dan/utils.py
+++ b/dan/utils.py
@@ -16,26 +16,12 @@ class MLflowNotInstalled(Exception):
     """
 
 
-class Token(NamedTuple):
-    encoded: str
-    display: str
-
-
-class LMTokenMapping(NamedTuple):
-    space: Token = Token("⎵", " ")
-    linebreak: Token = Token("↵", "\n")
-    ctc: Token = Token("◌", "<ctc>")
-
-    @property
-    def display(self):
-        return {a.encoded: a.display for a in self}
-
-    @property
-    def encode(self):
-        return {a.display: a.encoded for a in self}
-
-    def encode_token(self, token: str) -> str:
-        return self.encode.get(token, token)
+LM_MAPPING = {
+    " ": "⎵",
+    "\n": "↵",
+    "<ctc>": "◌",
+    "<unk>": "⇨",
+}
 
 
 class EntityType(NamedTuple):
-- 
GitLab
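
Postscript: a minimal standalone sketch of the character-level encoding these
hunks implement. LM_MAPPING and both comprehensions are copied from the patch
itself; the sample sentence and charset are invented for illustration, and
language_lexicon.txt is left out because its construction does not appear in
these hunks.

# Sketch of the LM encoding introduced by this patch.
# LM_MAPPING is verbatim from dan/utils.py; the sample data is made up.
LM_MAPPING = {
    " ": "⎵",
    "\n": "↵",
    "<ctc>": "◌",
    "<unk>": "⇨",
}


def format_text_language_model(text: str) -> str:
    # One space-separated symbol per character; spaces and linebreaks are
    # replaced by visible placeholders so that whitespace-splitting n-gram
    # tooling can treat them as ordinary tokens.
    return " ".join(
        [
            LM_MAPPING[token] if token in LM_MAPPING else token
            for token in list(text.strip())
        ]
    )


print(format_text_language_model("le chat\nnoir"))
# -> l e ⎵ c h a t ↵ n o i r

# language_tokens.txt mirrors the loop in format_language_model(): every
# charset symbol is encoded through LM_MAPPING, then the special CTC blank
# token is appended last.
charset = set("le chat\nnoir")
language_tokens = [
    LM_MAPPING[token] if token in LM_MAPPING else token
    for token in sorted(charset)
]
language_tokens.append(LM_MAPPING["<ctc>"])
print("\n".join(language_tokens))

Mapping the space to ⎵ and the linebreak to ↵ keeps exactly one visible
symbol per character, which is what lets language_corpus.txt be split on
whitespace without any escaping.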