From d419f3773de09f6ffeeab8d1c45c01b0c97648a3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sol=C3=A8ne=20Tarride?= <starride@teklia.com>
Date: Fri, 15 Sep 2023 15:00:50 +0200
Subject: [PATCH] Generate LM files during data extraction

---
 dan/datasets/extract/extract.py | 56 ++++++++++++++++++++++++++++++---
 dan/datasets/extract/utils.py   | 21 ++++++++++---
 dan/utils.py                    |  8 +++++
 3 files changed, 76 insertions(+), 9 deletions(-)

diff --git a/dan/datasets/extract/extract.py b/dan/datasets/extract/extract.py
index 0251e654..e337fa06 100644
--- a/dan/datasets/extract/extract.py
+++ b/dan/datasets/extract/extract.py
@@ -33,9 +33,10 @@ from dan.datasets.extract.utils import (
     download_image,
     get_bbox,
     insert_token,
-    remove_spaces,
+    normalize_linebreaks,
+    normalize_spaces
 )
-from dan.utils import EntityType, parse_tokens
+from dan.utils import LM_MAPPING, EntityType, parse_tokens
 from line_image_extractor.extractor import extract
 from line_image_extractor.image_utils import (
     BoundingBox,
@@ -93,6 +94,9 @@ class ArkindexExtractor:
 
         self.data: Dict = defaultdict(dict)
         self.charset = set()
+        self.language_corpus = []
+        self.language_tokens = []
+        self.language_lexicon = []
 
         # Image download tasks to process
         self.tasks: List[Dict[str, str]] = []
@@ -254,7 +258,8 @@ class ArkindexExtractor:
 
     def format_text(self, text: str, charset: Optional[set] = None):
         if not self.keep_spaces:
-            text = remove_spaces(text)
+            text = normalize_spaces(text)
+            text = normalize_linebreaks(text)
 
         # Replace unknown characters by the unknown token
         if charset is not None:
@@ -265,9 +270,19 @@ class ArkindexExtractor:
                     for unknown_char in unknown_charset
                 }
             )
-
         return text.strip()
 
+    def format_text_language_model(self, text: str):
+        """
+        Format text for the language model. Return the text tokenized at character-level.
+        """
+        return " ".join(
+            [
+                LM_MAPPING[token] if token in LM_MAPPING else token
+                for token in list(text.strip())
+            ]
+        )
+
     def process_element(
         self,
         element: Element,
@@ -305,6 +320,8 @@ class ArkindexExtractor:
 
         self.data[split][str(image_path)] = text
         self.charset = self.charset.union(set(text))
+        if split == "train":
+            self.language_corpus.append(self.format_text_language_model(text))
 
     def process_parent(
         self,
@@ -340,6 +357,27 @@ class ArkindexExtractor:
                 except ProcessingError as e:
                     logger.warning(f"Skipping {element.id}: {str(e)}")
 
+    def format_lm_files(self) -> None:
+        """
+        Convert charset to a LM-compatible charset. Ensure that special LM tokens do not appear in the charset.
+        """
+        for token in sorted(list(self.charset)):
+            assert (
+                token not in LM_MAPPING.values()
+            ), f"Special token {token} is reserved for language modeling."
+            self.language_tokens.append(
+                LM_MAPPING[token]
+            ) if token in LM_MAPPING else self.language_tokens.append(token)
+
+        # Add the special blank token
+        self.language_tokens.append(LM_MAPPING["<ctc>"])
+
+        # Build lexicon
+        assert all(
+            [len(token) == 1 for token in self.language_lexicon]
+        ), "Tokens should be single characters."
+        self.language_lexicon = [f"{token} {token}" for token in self.language_tokens]
+
     def export(self):
         (self.output / "labels.json").write_text(
             json.dumps(
@@ -348,6 +386,15 @@ class ArkindexExtractor:
                 indent=4,
             )
         )
+        (self.output / "language_corpus.txt").write_text(
+            "\n".join(self.language_corpus)
+        )
+        (self.output / "language_tokens.txt").write_text(
+            "\n".join(self.language_tokens)
+        )
+        (self.output / "language_lexicon.txt").write_text(
+            "\n".join(self.language_lexicon)
+        )
         (self.output / "charset.pkl").write_bytes(
             pickle.dumps(sorted(list(self.charset)))
         )
@@ -408,6 +455,7 @@ class ArkindexExtractor:
                     pbar.refresh()
 
         self.download_images()
+        self.format_lm_files()
         self.export()
 
 
diff --git a/dan/datasets/extract/utils.py b/dan/datasets/extract/utils.py
index e6fb5296..e178c6dd 100644
--- a/dan/datasets/extract/utils.py
+++ b/dan/datasets/extract/utils.py
@@ -21,7 +21,8 @@ logger = logging.getLogger(__name__)
 DOWNLOAD_TIMEOUT = (30, 60)
 
 # replace \t with regular space and consecutive spaces
-TRIM_REGEX = re.compile(r"\t?(?: +)")
+TRIM_SPACE_REGEX = re.compile(r"[\t| ]+")
+TRIM_RETURN_REGEX = re.compile(r"[\r|\n]+")
 
 
 def _retry_log(retry_state, *args, **kwargs):
@@ -80,11 +81,20 @@ def insert_token(text: str, entity_type: EntityType, offset: int, length: int) -
         + (entity_type.end if entity_type else "")
     )
 
+def normalize_linebreaks(text: str) -> str:
+    """
+    Remove begin/ending linebreaks
+    Replace \r with regular linebreak and consecutive linebreaks
+    """
+    return TRIM_RETURN_REGEX.sub("\n", text.strip())
+
 
-def remove_spaces(text: str) -> str:
-    # remove begin/ending spaces
-    # replace \t with regular space and consecutive spaces
-    return TRIM_REGEX.sub(" ", text.strip())
+def normalize_spaces(text: str) -> str:
+    """
+    Remove begin/ending spaces
+    Replace \t with regular space and consecutive spaces
+    """
+    return TRIM_SPACE_REGEX.sub(" ", text.strip())
 
 
 def get_bbox(polygon: List[List[int]]) -> str:
@@ -96,3 +106,4 @@ def get_bbox(polygon: List[List[int]]) -> str:
     x, y = min(all_x), min(all_y)
     width, height = max(all_x) - x, max(all_y) - y
     return ",".join(list(map(str, [int(x), int(y), int(width), int(height)])))
+
diff --git a/dan/utils.py b/dan/utils.py
index f813723e..97135eba 100644
--- a/dan/utils.py
+++ b/dan/utils.py
@@ -16,6 +16,14 @@ class MLflowNotInstalled(Exception):
     """
 
 
+LM_MAPPING = {
+    " ": "⎵",
+    "\n": "↵",
+    "<ctc>": "◌",
+    "<unk>": "⁇",
+}
+
+
 class EntityType(NamedTuple):
     start: str
     end: str = ""
-- 
GitLab