Support subword and word language models

Merged Solene Tarride requested to merge subword-and-word-lm into main
All threads resolved!
2 files
@@ -30,8 +30,10 @@ from dan.datasets.extract.exceptions import (
     UnknownTokenInText,
 )
 from dan.datasets.extract.utils import (
+    Tokenizer,
     download_image,
     get_bbox,
+    get_vocabulary,
     insert_token,
     normalize_linebreaks,
     normalize_spaces,
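
Two helpers are newly imported from dan.datasets.extract.utils: Tokenizer, which is defined in the second file of this merge request (not shown here), and get_vocabulary, which is used further down to turn a tokenized corpus into a lexicon. A minimal sketch of what get_vocabulary is assumed to do, with the signature guessed from its call site:

# Hypothetical sketch of get_vocabulary, inferred from its call site below;
# the real helper lives in dan.datasets.extract.utils and is not shown in this diff.
from typing import List, Set


def get_vocabulary(tokenized_text: List[str]) -> Set[str]:
    """Return the unique tokens of a corpus whose lines hold space-separated tokens."""
    return {token for line in tokenized_text for token in line.split()}
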
@@ -77,6 +79,7 @@ class ArkindexExtractor:
         keep_spaces: bool = False,
         image_extension: str = "",
         allow_empty: bool = False,
+        subword_vocab_size: int = 1000,
     ) -> None:
         self.folders = folders
         self.element_type = element_type
@@ -92,14 +95,14 @@ class ArkindexExtractor:
         self.image_extension = image_extension
         self.allow_empty = allow_empty
         self.mapping = LMTokenMapping()
         self.keep_spaces = keep_spaces
+        self.subword_vocab_size = subword_vocab_size

         self.data: Dict = defaultdict(dict)
         self.charset = set()
-        self.language_corpus = []
+        self.language_corpus = defaultdict(list)
         self.language_tokens = []
-        self.language_lexicon = []
+        self.language_lexicon = defaultdict(list)

         # Image download tasks to process
         self.tasks: List[Dict[str, str]] = []
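
The corpus and lexicon attributes switch from flat lists to defaultdict(list), keyed by tokenization level, so character, word and subword resources accumulate side by side. Roughly, the assumed shape after extraction:

# Illustration (assumption) of the per-level containers after extraction.
from collections import defaultdict

language_corpus = defaultdict(list)
language_corpus["characters"].append("h e l l o")  # one character per token
language_corpus["words"].append("hello")           # one word per token
language_corpus["subwords"].append("hel lo")       # hypothetical subword split
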
@@ -275,12 +278,6 @@ class ArkindexExtractor:
         )
         return text.strip()

-    def format_text_language_model(self, text: str):
-        """
-        Format text for the language model. Return the text tokenized at character-level.
-        """
-        return " ".join(map(self.mapping.encode_token, list(text.strip())))
-
     def process_element(
         self,
         element: Element,
@@ -319,10 +316,6 @@ class ArkindexExtractor:
         self.data[split][str(image_path)] = text
         self.charset = self.charset.union(set(text))

-        # Language model should be built using only text from the training set
-        if split == "train":
-            self.language_corpus.append(self.format_text_language_model(text))
-
     def process_parent(
         self,
         pbar,
@@ -361,6 +354,11 @@ class ArkindexExtractor:
"""
Convert charset to a LM-compatible charset. Ensure that special LM tokens do not appear in the charset.
"""
logger.info("Preparing language resources")
# Add unknown token to charset
self.charset.add(self.unknown_token)
# Build LM tokens
for token in sorted(list(self.charset)):
assert (
token not in self.mapping.encode.values()
@@ -368,15 +366,40 @@ class ArkindexExtractor:
             self.language_tokens.append(
                 self.mapping.encode[token]
             ) if token in self.mapping.encode else self.language_tokens.append(token)

         # Add the special blank token
         self.language_tokens.append(self.mapping.ctc.encoded)

-        # Build lexicon
-        assert all(
-            [len(token) == 1 for token in self.language_lexicon]
-        ), "Tokens should be single characters."
-        self.language_lexicon = [f"{token} {token}" for token in self.language_tokens]
+        # Build LM corpus
+        train_corpus = [
+            text.replace(self.mapping.linebreak.display, self.mapping.space.display)
+            for text in self.data["train"].values()
+        ]
+        tokenizer = Tokenizer(
+            training_corpus=train_corpus,
+            charset=self.language_tokens,
+            unknown_token=self.unknown_token,
+            outdir=self.output / "language_model",
+            mapping=self.mapping,
+            tokens=self.tokens,
+            subword_vocab_size=self.subword_vocab_size,
+        )
+        for level, tokenize in (
+            ("characters", tokenizer.char_tokenize),
+            ("words", tokenizer.word_tokenize),
+            ("subwords", tokenizer.subword_tokenize),
+        ):
+            self.language_corpus[level] = list(map(tokenize, train_corpus))
+
+        # Build LM lexicon
+        self.language_lexicon["characters"] = [
+            f"{token} {token}" for token in self.language_tokens
+        ]
+        for level in ["words", "subwords"]:
+            self.language_lexicon[level] = [
+                f"{token} {tokenizer.char_tokenize(token)}"
+                for token in get_vocabulary(self.language_corpus[level])
+            ]

     def export(self):
         (self.output / "labels.json").write_text(
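
The per-level corpora and lexicons are produced by the new Tokenizer, imported above from dan.datasets.extract.utils (the second file of this merge request, not shown here). Its constructor arguments and the three tokenize methods below follow the call site above; everything else is an assumption, including the use of sentencepiece for the subword model:

# Simplified sketch of the assumed Tokenizer interface. Constructor arguments
# charset, unknown_token, mapping and tokens are accepted but ignored here for
# brevity; the real implementation may handle special tokens differently.
from pathlib import Path
from typing import List

import sentencepiece as spm


class Tokenizer:
    def __init__(
        self,
        training_corpus: List[str],
        outdir: Path,
        subword_vocab_size: int = 1000,
        **kwargs,
    ) -> None:
        self.outdir = Path(outdir)
        self.outdir.mkdir(parents=True, exist_ok=True)
        # Train a subword model on the raw training corpus
        corpus_file = self.outdir / "corpus_raw.txt"
        corpus_file.write_text("\n".join(training_corpus))
        spm.SentencePieceTrainer.train(
            input=str(corpus_file),
            model_prefix=str(self.outdir / "subword_tokenizer"),
            vocab_size=subword_vocab_size,
        )
        self.sentencepiece_model = spm.SentencePieceProcessor(
            model_file=str(self.outdir / "subword_tokenizer.model")
        )

    def char_tokenize(self, text: str) -> str:
        # One character per token, space-separated (mirroring the removed
        # format_text_language_model helper, minus the special-token mapping).
        return " ".join(list(text))

    def word_tokenize(self, text: str) -> str:
        # One word per token.
        return " ".join(text.split())

    def subword_tokenize(self, text: str) -> str:
        # Space-separated subwords from the trained sentencepiece model.
        return " ".join(self.sentencepiece_model.encode(text, out_type=str))

Producing one corpus and one lexicon per level means the decoder can later pick a character, word or subword n-gram language model without re-running extraction.
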
@@ -386,15 +409,16 @@ class ArkindexExtractor:
                 indent=4,
             )
         )
-        (self.output / "language_model" / "corpus.txt").write_text(
-            "\n".join(self.language_corpus)
-        )
+        for level in ["characters", "words", "subwords"]:
+            (self.output / "language_model" / f"corpus_{level}.txt").write_text(
+                "\n".join(self.language_corpus[level])
+            )
+            (self.output / "language_model" / f"lexicon_{level}.txt").write_text(
+                "\n".join(self.language_lexicon[level])
+            )
         (self.output / "language_model" / "tokens.txt").write_text(
             "\n".join(self.language_tokens)
         )
-        (self.output / "language_model" / "lexicon.txt").write_text(
-            "\n".join(self.language_lexicon)
-        )
         (self.output / "charset.pkl").write_bytes(
             pickle.dumps(sorted(list(self.charset)))
         )
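
The export step now writes one corpus and one lexicon per tokenization level next to the existing tokens.txt, i.e. corpus_characters.txt, corpus_words.txt, corpus_subwords.txt and their lexicon_* counterparts under language_model/. Each lexicon line maps a vocabulary entry to its character-level spelling (f"{token} {tokenizer.char_tokenize(token)}"), so, assuming char_tokenize simply space-joins characters, a word-level entry would look roughly like:

kingdom k i n g d o m

This is illustrative only: the real files go through LMTokenMapping, so whitespace and other special characters are written as their encoded tokens rather than literal characters.
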
@@ -477,6 +501,7 @@ def run(
     image_format: str,
     keep_spaces: bool,
     allow_empty: bool,
+    subword_vocab_size: int,
 ):
     assert database.exists(), f"No file found @ {database}"
     open_database(path=database)
@@ -503,4 +528,5 @@ def run(
         keep_spaces=keep_spaces,
         image_extension=image_format,
         allow_empty=allow_empty,
+        subword_vocab_size=subword_vocab_size,
     ).run()
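
Finally, run() gains a subword_vocab_size argument and forwards it to ArkindexExtractor, so the subword vocabulary size can be chosen at extraction time. A hypothetical CLI wiring for it; the flag name, default and help text are assumptions, only the run(subword_vocab_size=...) plumbing appears in this diff:

# Hypothetical argparse wiring; not part of this merge request's diff.
import argparse

parser = argparse.ArgumentParser(description="Extract a dataset and build language model resources")
parser.add_argument(
    "--subword-vocab-size",
    type=int,
    default=1000,
    help="Vocabulary size used to train the subword tokenizer.",
)
args = parser.parse_args()
# run(..., subword_vocab_size=args.subword_vocab_size)
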