From e9c3ac5b8f1fb6ac218aceca661e26ca2e7cffea Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sol=C3=A8ne=20Tarride?= <starride@teklia.com>
Date: Fri, 15 Sep 2023 15:00:50 +0200
Subject: [PATCH] Generate LM files during data extraction

---
 dan/datasets/extract/extract.py | 27 +++++++++++++++++----------
 dan/datasets/extract/utils.py   | 16 +++++++---------
 dan/utils.py                    | 26 ++++++--------------------
 3 files changed, 30 insertions(+), 39 deletions(-)

diff --git a/dan/datasets/extract/extract.py b/dan/datasets/extract/extract.py
index a7fa8d32..c4d71632 100644
--- a/dan/datasets/extract/extract.py
+++ b/dan/datasets/extract/extract.py
@@ -34,9 +34,9 @@ from dan.datasets.extract.utils import (
     get_bbox,
     insert_token,
     normalize_linebreaks,
-    normalize_spaces,
+    normalize_spaces
 )
-from dan.utils import EntityType, LMTokenMapping, parse_tokens
+from dan.utils import LM_MAPPING, EntityType, parse_tokens
 from line_image_extractor.extractor import extract
 from line_image_extractor.image_utils import (
     BoundingBox,
@@ -279,7 +279,12 @@ class ArkindexExtractor:
         """
         Format text for the language model. Return the text tokenized at character-level.
         """
-        return " ".join(map(self.mapping.encode_token, list(text.strip())))
+        return " ".join(
+            [
+                LM_MAPPING[token] if token in LM_MAPPING else token
+                for token in list(text.strip())
+            ]
+        )
 
     def process_element(
         self,
@@ -318,6 +323,8 @@ class ArkindexExtractor:
 
         self.data[split][str(image_path)] = text
         self.charset = self.charset.union(set(text))
+        if split == "train":
+            self.language_corpus.append(self.format_text_language_model(text))
 
         # Language model should be built using only text from the training set
         if split == "train":
@@ -363,14 +370,14 @@ class ArkindexExtractor:
         """
         for token in sorted(list(self.charset)):
             assert (
-                token not in self.mapping.encode.values()
+                token not in LM_MAPPING.values()
             ), f"Special token {token} is reserved for language modeling."
             self.language_tokens.append(
-                self.mapping.encode[token]
-            ) if token in self.mapping.encode else self.language_tokens.append(token)
+                LM_MAPPING[token]
+            ) if token in LM_MAPPING else self.language_tokens.append(token)
 
         # Add the special blank token
-        self.language_tokens.append(self.mapping.ctc.encoded)
+        self.language_tokens.append(LM_MAPPING["<ctc>"])
 
         # Build lexicon
         assert all(
@@ -386,13 +393,13 @@ class ArkindexExtractor:
                 indent=4,
             )
         )
-        (self.output / "language_model" / "corpus.txt").write_text(
+        (self.output / "language_corpus.txt").write_text(
             "\n".join(self.language_corpus)
         )
-        (self.output / "language_model" / "tokens.txt").write_text(
+        (self.output / "language_tokens.txt").write_text(
             "\n".join(self.language_tokens)
         )
-        (self.output / "language_model" / "lexicon.txt").write_text(
+        (self.output / "language_lexicon.txt").write_text(
             "\n".join(self.language_lexicon)
         )
         (self.output / "charset.pkl").write_bytes(
diff --git a/dan/datasets/extract/utils.py b/dan/datasets/extract/utils.py
index 79e5ef6f..2863410d 100644
--- a/dan/datasets/extract/utils.py
+++ b/dan/datasets/extract/utils.py
@@ -21,8 +21,8 @@ logger = logging.getLogger(__name__)
 DOWNLOAD_TIMEOUT = (30, 60)
 
 # replace \t with regular space and consecutive spaces
-TRIM_SPACE_REGEX = re.compile(r"[\t ]+")
-TRIM_RETURN_REGEX = re.compile(r"[\r\n]+")
+TRIM_SPACE_REGEX = re.compile(r"[\t| ]+")
+TRIM_RETURN_REGEX = re.compile(r"[\r|\n]+")
 
 
 def _retry_log(retry_state, *args, **kwargs):
@@ -89,21 +89,18 @@ def insert_token(text: str, entity_type: EntityType, offset: int, length: int) -
         + (entity_type.end if entity_type else "")
     )
 
-
 def normalize_linebreaks(text: str) -> str:
     """
-    Remove begin/ending linebreaks.
-    Replace \r with regular linebreak and consecutive linebreaks.
-    :param text: Text to normalize.
+    Remove begin/ending linebreaks
+    Replace \r with regular linebreak and consecutive linebreaks
     """
     return TRIM_RETURN_REGEX.sub("\n", text.strip())
 
 
 def normalize_spaces(text: str) -> str:
     """
-    Remove begin/ending spaces.
-    Replace \t with regular space and consecutive spaces.
-    :param text: Text to normalize.
+    Remove begin/ending spaces
+    Replace \t with regular space and consecutive spaces
     """
     return TRIM_SPACE_REGEX.sub(" ", text.strip())
 
@@ -117,3 +114,4 @@ def get_bbox(polygon: List[List[int]]) -> str:
     x, y = min(all_x), min(all_y)
     width, height = max(all_x) - x, max(all_y) - y
     return ",".join(list(map(str, [int(x), int(y), int(width), int(height)])))
+
diff --git a/dan/utils.py b/dan/utils.py
index 69e7d82a..97135eba 100644
--- a/dan/utils.py
+++ b/dan/utils.py
@@ -16,26 +16,12 @@ class MLflowNotInstalled(Exception):
     """
 
 
-class Token(NamedTuple):
-    encoded: str
-    display: str
-
-
-class LMTokenMapping(NamedTuple):
-    space: Token = Token("⎵", " ")
-    linebreak: Token = Token("↵", "\n")
-    ctc: Token = Token("◌", "<ctc>")
-
-    @property
-    def display(self):
-        return {a.encoded: a.display for a in self}
-
-    @property
-    def encode(self):
-        return {a.display: a.encoded for a in self}
-
-    def encode_token(self, token: str) -> str:
-        return self.encode.get(token, token)
+LM_MAPPING = {
+    " ": "⎵",
+    "\n": "↵",
+    "<ctc>": "◌",
+    "<unk>": "⁇",
+}
 
 
 class EntityType(NamedTuple):
-- 
GitLab