From d419f3773de09f6ffeeab8d1c45c01b0c97648a3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sol=C3=A8ne=20Tarride?= <starride@teklia.com>
Date: Fri, 15 Sep 2023 15:00:50 +0200
Subject: [PATCH] Generate LM files during data extraction

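Alongside labels.json and charset.pkl, the extraction step now writes
three files to the output directory that can be used to train and
decode with an external language model:

- language_corpus.txt: one line per training transcription, tokenized
  at the character level, with spaces mapped to ⎵ and linebreaks to ↵
- language_tokens.txt: one LM token per line, i.e. the mapped charset
  plus the CTC blank symbol ◌
- language_lexicon.txt: one entry per token; since the model is
  character-based, each token maps to itself (e.g. "e e")

For example, a training transcription "le 1er\nmai" produces the
corpus line:

    l e ⎵ 1 e r ↵ m a i

As a minimal sketch of the intended downstream use (illustrative, not
part of this patch): assuming kenlm and torchaudio are installed, and
an n-gram model has been trained with, e.g.,
`lmplz -o 3 < language_corpus.txt > model.arpa`, the generated files
plug into a lexicon-based CTC beam-search decoder:

    # Illustrative only: model.arpa is a hypothetical kenlm model
    # trained on language_corpus.txt.
    from torchaudio.models.decoder import ctc_decoder

    decoder = ctc_decoder(
        lexicon="language_lexicon.txt",
        tokens="language_tokens.txt",
        lm="model.arpa",
        blank_token="◌",
        sil_token="⎵",
        unk_word="⁇",
    )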
---
 dan/datasets/extract/extract.py | 56 ++++++++++++++++++++++++++++++---
 dan/datasets/extract/utils.py   | 21 ++++++++++---
 dan/utils.py                    |  8 +++++
 3 files changed, 76 insertions(+), 9 deletions(-)

diff --git a/dan/datasets/extract/extract.py b/dan/datasets/extract/extract.py
index 0251e654..e337fa06 100644
--- a/dan/datasets/extract/extract.py
+++ b/dan/datasets/extract/extract.py
@@ -33,9 +33,10 @@ from dan.datasets.extract.utils import (
     download_image,
     get_bbox,
     insert_token,
-    remove_spaces,
+    normalize_linebreaks,
+    normalize_spaces,
 )
-from dan.utils import EntityType, parse_tokens
+from dan.utils import LM_MAPPING, EntityType, parse_tokens
 from line_image_extractor.extractor import extract
 from line_image_extractor.image_utils import (
     BoundingBox,
@@ -93,6 +94,9 @@ class ArkindexExtractor:
 
         self.data: Dict = defaultdict(dict)
         self.charset = set()
+        self.language_corpus = []
+        self.language_tokens = []
+        self.language_lexicon = []
 
         # Image download tasks to process
         self.tasks: List[Dict[str, str]] = []
@@ -254,7 +258,8 @@ class ArkindexExtractor:
 
     def format_text(self, text: str, charset: Optional[set] = None):
         if not self.keep_spaces:
-            text = remove_spaces(text)
+            text = normalize_spaces(text)
+            text = normalize_linebreaks(text)
 
         # Replace unknown characters by the unknown token
         if charset is not None:
@@ -265,9 +270,19 @@ class ArkindexExtractor:
                     for unknown_char in unknown_charset
                 }
             )
-
         return text.strip()
 
+    def format_text_language_model(self, text: str):
+        """
+        Format text for the language model: return the text tokenized at the character level.
+        """
+        return " ".join(
+            [
+                LM_MAPPING.get(token, token)
+                for token in text.strip()
+            ]
+        )
+
     def process_element(
         self,
         element: Element,
@@ -305,6 +320,8 @@ class ArkindexExtractor:
 
         self.data[split][str(image_path)] = text
         self.charset = self.charset.union(set(text))
+        if split == "train":
+            self.language_corpus.append(self.format_text_language_model(text))
 
     def process_parent(
         self,
@@ -340,6 +357,27 @@ class ArkindexExtractor:
         except ProcessingError as e:
             logger.warning(f"Skipping {element.id}: {str(e)}")
 
+    def format_lm_files(self) -> None:
+        """
+        Convert the dataset charset into an LM-compatible token list, ensuring that reserved LM tokens do not appear in the charset.
+        """
+        for token in sorted(list(self.charset)):
+            assert (
+                token not in LM_MAPPING.values()
+            ), f"Special token {token} is reserved for language modeling."
+            self.language_tokens.append(
+                LM_MAPPING.get(token, token)
+            )
+
+        # Add the special blank token
+        self.language_tokens.append(LM_MAPPING["<ctc>"])
+
+        # Build the lexicon: each character-level token maps to itself
+        assert all(
+            len(token) == 1 for token in self.language_tokens
+        ), "Tokens should be single characters."
+        self.language_lexicon = [f"{token} {token}" for token in self.language_tokens]
+
     def export(self):
         (self.output / "labels.json").write_text(
             json.dumps(
@@ -348,6 +386,15 @@ class ArkindexExtractor:
                 indent=4,
             )
         )
+        (self.output / "language_corpus.txt").write_text(
+            "\n".join(self.language_corpus)
+        )
+        (self.output / "language_tokens.txt").write_text(
+            "\n".join(self.language_tokens)
+        )
+        (self.output / "language_lexicon.txt").write_text(
+            "\n".join(self.language_lexicon)
+        )
         (self.output / "charset.pkl").write_bytes(
             pickle.dumps(sorted(list(self.charset)))
         )
@@ -408,6 +455,7 @@ class ArkindexExtractor:
             pbar.refresh()
 
         self.download_images()
+        self.format_lm_files()
         self.export()
 
 
diff --git a/dan/datasets/extract/utils.py b/dan/datasets/extract/utils.py
index e6fb5296..e178c6dd 100644
--- a/dan/datasets/extract/utils.py
+++ b/dan/datasets/extract/utils.py
@@ -21,7 +21,8 @@ logger = logging.getLogger(__name__)
 
 DOWNLOAD_TIMEOUT = (30, 60)
 
 # replace \t with regular space and consecutive spaces
-TRIM_REGEX = re.compile(r"\t?(?: +)")
+TRIM_SPACE_REGEX = re.compile(r"[\t ]+")
+TRIM_RETURN_REGEX = re.compile(r"[\r\n]+")
 
 def _retry_log(retry_state, *args, **kwargs):
@@ -80,11 +81,20 @@ def insert_token(text: str, entity_type: EntityType, offset: int, length: int) -
         + (entity_type.end if entity_type else "")
     )
 
 
+def normalize_linebreaks(text: str) -> str:
+    """
+    Remove leading and trailing linebreaks, and replace carriage returns
+    and consecutive linebreaks with a single newline.
+    """
+    return TRIM_RETURN_REGEX.sub("\n", text.strip())
+
-def remove_spaces(text: str) -> str:
-    # remove begin/ending spaces
-    # replace \t with regular space and consecutive spaces
-    return TRIM_REGEX.sub(" ", text.strip())
+def normalize_spaces(text: str) -> str:
+    """
+    Remove leading and trailing spaces, and replace tabs and
+    consecutive spaces with a single space.
+    """
+    return TRIM_SPACE_REGEX.sub(" ", text.strip())
 
 
 def get_bbox(polygon: List[List[int]]) -> str:
@@ -96,3 +106,4 @@ def get_bbox(polygon: List[List[int]]) -> str:
     x, y = min(all_x), min(all_y)
     width, height = max(all_x) - x, max(all_y) - y
     return ",".join(list(map(str, [int(x), int(y), int(width), int(height)])))
+
diff --git a/dan/utils.py b/dan/utils.py
index f813723e..97135eba 100644
--- a/dan/utils.py
+++ b/dan/utils.py
@@ -16,6 +16,14 @@ class MLflowNotInstalled(Exception):
     """
 
 
+LM_MAPPING = {
+    " ": "⎵",
+    "\n": "↵",
+    "<ctc>": "◌",
+    "<unk>": "⁇",
+}
+
+
 class EntityType(NamedTuple):
     start: str
     end: str = ""
-- 
GitLab