From 9557004299534004b93faaa094b42fc30677e6cb Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sol=C3=A8ne=20Tarride?= <starride@teklia.com>
Date: Fri, 13 Oct 2023 16:01:07 +0200
Subject: [PATCH] Add vocabulary size parameter for subword tokenizer

---
 dan/datasets/extract/__init__.py | 7 +++++++
 dan/datasets/extract/extract.py  | 6 +++++-
 dan/datasets/extract/utils.py    | 7 +------
 3 files changed, 13 insertions(+), 7 deletions(-)

diff --git a/dan/datasets/extract/__init__.py b/dan/datasets/extract/__init__.py
index 80fc3de1..717eab5c 100644
--- a/dan/datasets/extract/__init__.py
+++ b/dan/datasets/extract/__init__.py
@@ -147,6 +147,13 @@ def add_extract_parser(subcommands) -> None:
         help="Images larger than this height will be resized to this width.",
     )
 
+    parser.add_argument(
+        "--subword-vocab-size",
+        type=int,
+        default=1000,
+        help="Size of the vocabulary to train the sentencepiece subword tokenizer needed for language model.",
+    )
+
     # Formatting arguments
     parser.add_argument(
         "--image-format",
diff --git a/dan/datasets/extract/extract.py b/dan/datasets/extract/extract.py
index 2befb801..9bff74d9 100644
--- a/dan/datasets/extract/extract.py
+++ b/dan/datasets/extract/extract.py
@@ -78,6 +78,7 @@ class ArkindexExtractor:
         keep_spaces: bool = False,
         image_extension: str = "",
         allow_empty: bool = False,
+        subword_vocab_size: int = 1000,
     ) -> None:
         self.folders = folders
         self.element_type = element_type
@@ -93,8 +94,8 @@ class ArkindexExtractor:
         self.image_extension = image_extension
         self.allow_empty = allow_empty
         self.mapping = LMTokenMapping()
-
         self.keep_spaces = keep_spaces
+        self.subword_vocab_size = subword_vocab_size
 
         self.data: Dict = defaultdict(dict)
         self.charset = set()
@@ -375,6 +376,7 @@ class ArkindexExtractor:
             outdir=self.output / "language_model",
             mapping=self.mapping,
             tokens=self.tokens,
+            subword_vocab_size=self.subword_vocab_size,
         )
         self.language_corpus["characters"] = [
             tokenizer.char_tokenize(doc) for doc in train_corpus
@@ -518,6 +520,7 @@ def run(
     image_format: str,
     keep_spaces: bool,
     allow_empty: bool,
+    subword_vocab_size: int,
 ):
     assert database.exists(), f"No file found @ {database}"
     open_database(path=database)
@@ -544,4 +547,5 @@ def run(
         keep_spaces=keep_spaces,
         image_extension=image_format,
         allow_empty=allow_empty,
+        subword_vocab_size=subword_vocab_size,
     ).run()
diff --git a/dan/datasets/extract/utils.py b/dan/datasets/extract/utils.py
index 2b93334c..7d5111d7 100644
--- a/dan/datasets/extract/utils.py
+++ b/dan/datasets/extract/utils.py
@@ -134,7 +134,7 @@ class Tokenizer:
         self.tokens = tokens
         self.mapping = mapping
         # Train the subword tokenizer
-        self.user_subword_vocab_size = subword_vocab_size
+        self.subword_vocab_size = subword_vocab_size
         self.sentencepiece_model = self.train_subword_tokenizer()
 
     @property
@@ -153,11 +153,6 @@ class Tokenizer:
     def special_tokens(self) -> List[str]:
         return list(set(self.ner_tokens + self.mapping_tokens))
 
-    @property
-    def subword_vocab_size(self):
-        n_words = len(set([word for doc in self.corpus for word in doc.split()]))
-        return min(self.user_subword_vocab_size, 3 * n_words)
-
     def train_subword_tokenizer(self):
         """
         Train a sentencepiece model on the training corpus.
-- 
GitLab