From 0f64242be2891bd334cdbc43c211c349f608c8f7 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sol=C3=A8ne=20Tarride?= <starride@teklia.com>
Date: Thu, 19 Oct 2023 16:15:35 +0200
Subject: [PATCH] Map unknown characters

---
 dan/datasets/extract/arkindex.py |  3 +++
 dan/datasets/extract/utils.py    | 12 ++++++++----
 2 files changed, 11 insertions(+), 4 deletions(-)

diff --git a/dan/datasets/extract/arkindex.py b/dan/datasets/extract/arkindex.py
index 6b13c5d8..88975536 100644
--- a/dan/datasets/extract/arkindex.py
+++ b/dan/datasets/extract/arkindex.py
@@ -371,8 +371,11 @@ class ArkindexExtractor:
             text.replace(self.mapping.linebreak.display, self.mapping.space.display)
             for text in self.data["train"].values()
         ]
+
         tokenizer = Tokenizer(
             training_corpus=train_corpus,
+            charset=self.language_tokens,
+            unknown_token=self.unknown_token,
             outdir=self.output / "language_model",
             mapping=self.mapping,
             tokens=self.tokens,
diff --git a/dan/datasets/extract/utils.py b/dan/datasets/extract/utils.py
index f0371b6a..60e597ec 100644
--- a/dan/datasets/extract/utils.py
+++ b/dan/datasets/extract/utils.py
@@ -131,9 +131,7 @@ def get_vocabulary(tokenized_text: List[str]) -> set[str]:
     Compute set of vocabulary from tokenzied text.
     :param tokenized_text: List of tokenized text.
     """
-    return sorted(
-        set([token for doc in tokenized_text for token in doc.split() if token != ""])
-    )
+    return sorted(set([token for doc in tokenized_text for token in doc.split()]))
 
 
 @dataclass
@@ -148,6 +146,8 @@ class Tokenizer:
     """
 
     training_corpus: List[str]
+    charset: List[str]
+    unknown_token: str
     outdir: Path
     mapping: LMTokenMapping
     tokens: Optional[EntityType] = None
@@ -225,7 +225,11 @@ class Tokenizer:
         Tokenize text into a string of space-separated characters.
         :param text: Text to be tokenized.
         """
-        return " ".join(self.encode(list(text)))
+        return " ".join(
+            self.encode(
+                [char if char in self.charset else self.unknown_token for char in text]
+            )
+        )
 
     def encode(self, text: List[str]) -> List[str]:
         """
-- 
GitLab