From 23cd8204db1dafa1f67a883da00a9b2e5c42e039 Mon Sep 17 00:00:00 2001
From: Solène Tarride <starride@teklia.com>
Date: Wed, 18 Oct 2023 08:32:34 +0200
Subject: [PATCH] Refactor Tokenizer into a dataclass and deduplicate corpus/lexicon building

---
 dan/datasets/extract/extract.py | 32 ++++++-------
 dan/datasets/extract/utils.py   | 80 +++++++++++++++++----------------
 tests/test_extract.py           |  5 +--
 3 files changed, 56 insertions(+), 61 deletions(-)

diff --git a/dan/datasets/extract/extract.py b/dan/datasets/extract/extract.py
index 7478f019..6b13c5d8 100644
--- a/dan/datasets/extract/extract.py
+++ b/dan/datasets/extract/extract.py
@@ -372,35 +372,29 @@ class ArkindexExtractor:
             for text in self.data["train"].values()
         ]
         tokenizer = Tokenizer(
-            train_corpus,
+            training_corpus=train_corpus,
             outdir=self.output / "language_model",
             mapping=self.mapping,
             tokens=self.tokens,
             subword_vocab_size=self.subword_vocab_size,
         )
-        self.language_corpus["characters"] = [
-            tokenizer.char_tokenize(doc) for doc in train_corpus
-        ]
-        self.language_corpus["words"] = [
-            tokenizer.word_tokenize(doc) for doc in train_corpus
-        ]
-        self.language_corpus["subwords"] = [
-            tokenizer.subword_tokenize(doc) for doc in train_corpus
-        ]
+
+        for level, tokenize in (
+            ("characters", tokenizer.char_tokenize),
+            ("words", tokenizer.word_tokenize),
+            ("subwords", tokenizer.subword_tokenize),
+        ):
+            self.language_corpus[level] = list(map(tokenize, train_corpus))
 
         # Build LM lexicon
         self.language_lexicon["characters"] = [
             f"{token} {token}" for token in self.language_tokens
         ]
-        self.language_lexicon["words"] = [
-            f"{word} {tokenizer.char_tokenize(word)}"
-            for word in get_vocabulary(self.language_corpus["words"])
-            if word != ""
-        ]
-        self.language_lexicon["subwords"] = [
-            f"{subword} {tokenizer.char_tokenize(subword)}"
-            for subword in get_vocabulary(self.language_corpus["subwords"])
-        ]
+        for level in ["words", "subwords"]:
+            self.language_lexicon[level] = [
+                f"{token} {tokenizer.char_tokenize(token)}"
+                for token in get_vocabulary(self.language_corpus[level])
+            ]
 
     def export(self):
         (self.output / "labels.json").write_text(
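
The table-driven loop introduced above is a general way to fan one training corpus out over several tokenization levels. A minimal, self-contained sketch of the same pattern, with toy tokenizers standing in for the real char/word/subword ones (the regex only approximates wordpunct-style splitting):

    import re

    def char_tokenize(text: str) -> str:
        # Stand-in: space-separate every character.
        return " ".join(text)

    def word_tokenize(text: str) -> str:
        # Stand-in: split into words or single punctuation marks.
        return " ".join(re.findall(r"\w+|[^\w\s]", text))

    train_corpus = ["Hello, world", "DAN reads lines"]
    language_corpus = {}
    for level, tokenize in (
        ("characters", char_tokenize),
        ("words", word_tokenize),
    ):
        # One loop body serves every level: map the callable over the corpus.
        language_corpus[level] = list(map(tokenize, train_corpus))

    assert language_corpus["words"][0] == "Hello , world"

Adding a fourth level is then a one-line change to the tuple rather than a new copy of the list comprehension.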
diff --git a/dan/datasets/extract/utils.py b/dan/datasets/extract/utils.py
index 092257ee..f0371b6a 100644
--- a/dan/datasets/extract/utils.py
+++ b/dan/datasets/extract/utils.py
@@ -1,9 +1,13 @@
 # -*- coding: utf-8 -*-
+import itertools
 import logging
+import operator
 import re
+from dataclasses import dataclass, field
 from io import BytesIO
 from pathlib import Path
-from typing import List
+from tempfile import NamedTemporaryFile
+from typing import Iterator, List, Optional, Union
 
 import requests
 import sentencepiece as spm
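
The `dataclasses` and `tempfile` imports added here support the refactor below. As a minimal sketch of the `field(init=False)` plus `__post_init__` idiom that the `Tokenizer` dataclass relies on (toy `Greeter` class, purely illustrative):

    from dataclasses import dataclass, field

    @dataclass
    class Greeter:
        name: str
        # Derived attribute: excluded from __init__, filled in __post_init__.
        greeting: str = field(init=False)

        def __post_init__(self) -> None:
            self.greeting = f"Hello, {self.name}!"

    assert Greeter(name="DAN").greeting == "Hello, DAN!"

`field(init=False)` keeps the derived attribute out of the generated constructor signature while still declaring it as part of the class.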
@@ -132,6 +136,7 @@ def get_vocabulary(tokenized_text: List[str]) -> set[str]:
     )
 
 
+@dataclass
 class Tokenizer:
     """
     A multi-level tokenizer (char, subword, word), where the subword tokenizer is trained using sentencepiece.
@@ -142,30 +147,27 @@ class Tokenizer:
     :param subword_vocab_size: Size of the vocabulary used to train the subword tokenizer.
     """
 
-    def __init__(
-        self,
-        training_corpus: List[str],
-        outdir: Path,
-        mapping: LMTokenMapping,
-        tokens: EntityType = None,
-        subword_vocab_size: int = 1000,
-    ) -> None:
-        self.corpus = training_corpus
-        self.outdir = outdir
-        self.prefix = f"{self.outdir}/subword_tokenizer"
-        self.tokens = tokens
-        self.mapping = mapping
-        # Train the subword tokenizer
-        self.subword_vocab_size = subword_vocab_size
-        self.sentencepiece_model = self.train_subword_tokenizer()
+    training_corpus: List[str]
+    outdir: Path
+    mapping: LMTokenMapping
+    tokens: Optional[EntityType] = None
+    subword_vocab_size: int = 1000
+    sentencepiece_model: spm.SentencePieceProcessor = field(init=False)
 
     @property
-    def ner_tokens(self) -> List[str]:
+    def prefix(self) -> Path:
+        return self.outdir / "subword_tokenizer"
+
+    @property
+    def ner_tokens(self) -> Union[List[str], Iterator[str]]:
         if self.tokens is None:
             return []
-        return [entity.start for entity in self.tokens.values()] + [
-            entity.end for entity in self.tokens.values() if entity.end != ""
-        ]
+        return itertools.chain(
+            map(operator.attrgetter("start"), self.tokens.values()),
+            filter(
+                operator.truth, map(operator.attrgetter("end"), self.tokens.values())
+            ),
+        )
 
     @property
     def mapping_tokens(self) -> List[str]:
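
The `ner_tokens` rewrite above yields tokens lazily instead of concatenating two intermediate lists. The same chain/attrgetter/truth combination in isolation, with a hypothetical `Entity` record standing in for the real `EntityType` values:

    import itertools
    import operator
    from dataclasses import dataclass

    @dataclass
    class Entity:
        start: str
        end: str

    entities = [Entity("<per>", "</per>"), Entity("<date>", "")]

    tokens = itertools.chain(
        map(operator.attrgetter("start"), entities),
        # operator.truth drops falsy values, here the empty `end` strings.
        filter(operator.truth, map(operator.attrgetter("end"), entities)),
    )
    assert list(tokens) == ["<per>", "<date>", "</per>"]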
@@ -173,42 +175,42 @@ class Tokenizer:
 
     @property
     def special_tokens(self) -> List[str]:
-        return list(set(self.ner_tokens + self.mapping_tokens))
+        return list(set(itertools.chain(self.mapping_tokens, self.ner_tokens)))
 
-    def train_subword_tokenizer(self):
+    def __post_init__(self) -> None:
         """
         Train a sentencepiece model on the training corpus and load the resulting model.
         """
         # Write the corpus to a text file
-        corpus_file = Path(self.outdir / "tmp.txt")
-        corpus_file.write_text("\n".join(self.corpus))
-
-        # Train the tokenizer
-        logger.info("Training sentencepiece model for subword tokenization")
-        spm.SentencePieceTrainer.train(
-            input=str(corpus_file),
-            vocab_size=self.subword_vocab_size,
-            model_prefix=self.prefix,
-            user_defined_symbols=self.special_tokens,
+        logger.info("Training a sentencepiece model for subword tokenization")
+        with NamedTemporaryFile(dir=self.outdir, suffix=".txt", mode="w") as tmp:
+            tmp.write("\n".join(self.training_corpus))
+            tmp.flush()
+            spm.SentencePieceTrainer.train(
+                input=tmp.name,
+                vocab_size=self.subword_vocab_size,
+                model_prefix=self.prefix,
+                user_defined_symbols=self.special_tokens,
+            )
+
+        # Load the model
+        self.sentencepiece_model = spm.SentencePieceProcessor(
+            model_file=str(self.prefix.with_suffix(".model"))
         )
 
-        # Delete the corpus file and load the model
-        corpus_file.unlink()
-        return spm.SentencePieceProcessor(model_file=f"{self.prefix}.model")
-
     def subword_tokenize(self, text: str) -> str:
         """
         Tokenize into subwords. Sampling is disabled to ensure reproducibility.
         """
         tokens = self.sentencepiece_model.encode(text, out_type=str)
-        return " ".join(["".join(self.encode(subword)) for subword in tokens])
+        return " ".join(map("".join, map(self.encode, tokens)))
 
     def word_tokenize(self, text: str) -> str:
         """
         Tokenize text into a string of space-separated words. Spaces (⎵) and NER tokens are considered words.
         :param text: Text to be tokenized.
         """
-        words = ["".join(self.encode(word)) for word in wordpunct_tokenize(text)]
+        words = list(map("".join, map(self.encode, wordpunct_tokenize(text))))
         return " ".join(
             [
                 word + f" {self.mapping.space.encoded}"
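
The `NamedTemporaryFile` context manager above guarantees the corpus file is removed even if training raises, replacing the manual `tmp.txt` write and `unlink()`. The pattern in isolation, with a stand-in `train` function since actually running sentencepiece would need a realistic corpus:

    from pathlib import Path
    from tempfile import NamedTemporaryFile

    def train(input: str) -> int:
        # Stand-in for spm.SentencePieceTrainer.train: reads the file by name.
        return len(Path(input).read_text().splitlines())

    corpus = ["line one", "line two"]
    with NamedTemporaryFile(dir=".", suffix=".txt", mode="w") as tmp:
        tmp.write("\n".join(corpus))
        tmp.flush()  # make the buffered text visible to the trainer
        assert train(input=tmp.name) == 2
    # The temporary file is deleted automatically on exiting the block.

One caveat: reopening a NamedTemporaryFile by name while it is still open works on POSIX systems but not on Windows.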
diff --git a/tests/test_extract.py b/tests/test_extract.py
index 83fd3946..120cd788 100644
--- a/tests/test_extract.py
+++ b/tests/test_extract.py
@@ -644,15 +644,14 @@ def test_extract(
     assert (
         output / "language_model" / "corpus_words.txt"
     ).read_text() == expected_word_language_corpus
-    print((output / "language_model" / "corpus_subwords.txt").read_text())
-    print(expected_subword_language_corpus)
+
     assert (
         output / "language_model" / "corpus_subwords.txt"
     ).read_text() == expected_subword_language_corpus
 
     # Check "language_tokens.txt"
     expected_language_tokens = [
-        t if t != " " else "▁" for t in sorted(list(expected_charset))
+        "▁" if t.isspace() else t for t in sorted(list(expected_charset))
     ]
     expected_language_tokens.append("◌")
     assert (output / "language_model" / "tokens.txt").read_text() == "\n".join(
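
The `t.isspace()` change above broadens the replacement from the literal space character to any whitespace, which matters if the charset ever contains tabs or newlines:

    charset = ["a", " ", "\t", "b"]
    old = [t if t != " " else "▁" for t in charset]   # leaves "\t" untouched
    new = ["▁" if t.isspace() else t for t in charset]
    assert old == ["a", "▁", "\t", "b"]
    assert new == ["a", "▁", "▁", "b"]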
-- 
GitLab