
Support subword and word language models

Merged Solene Tarride requested to merge subword-and-word-lm into main
All threads resolved!
2 files changed  +107  −23
@@ -14,6 +14,7 @@ import cv2
 import numpy as np
 from PIL import Image
 from tqdm import tqdm
+from nltk.tokenize import wordpunct_tokenize
 from arkindex_export import open_database
 from dan.datasets.extract.db import (
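For reviewers unfamiliar with it, NLTK's `wordpunct_tokenize` splits a string into alphabetic and punctuation tokens, which is what the word-level corpus built below relies on:

```python
from nltk.tokenize import wordpunct_tokenize

# Punctuation is kept, but as separate tokens.
wordpunct_tokenize("Hello, world!")
# ['Hello', ',', 'world', '!']
```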
@@ -35,6 +36,7 @@ from dan.datasets.extract.utils import (
     insert_token,
     normalize_linebreaks,
     normalize_spaces,
+    Tokenizer,
 )
 from dan.utils import EntityType, LMTokenMapping, parse_tokens
 from line_image_extractor.extractor import extract
@@ -97,9 +99,9 @@ class ArkindexExtractor:
         self.data: Dict = defaultdict(dict)
         self.charset = set()
-        self.language_corpus = []
+        self.language_corpus = defaultdict(list)
         self.language_tokens = []
-        self.language_lexicon = []
+        self.language_lexicon = defaultdict(list)
         # Image download tasks to process
         self.tasks: List[Dict[str, str]] = []
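With this change the corpus and lexicon are keyed by tokenization level rather than being flat lists. A minimal illustration of the new shape (the keys match the extraction code further down; the content is made up):

```python
from collections import defaultdict

# One corpus per tokenization level, one line per training document.
language_corpus = defaultdict(list)
language_corpus["characters"].append("t e x t")
language_corpus["words"].append("some word level line")
language_corpus["subwords"].append("som e word lev el line")
```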
@@ -275,12 +277,6 @@
         )
         return text.strip()
-    def format_text_language_model(self, text: str):
-        """
-        Format text for the language model. Return the text tokenized at character-level.
-        """
-        return " ".join(map(self.mapping.encode_token, list(text.strip())))
     def process_element(
         self,
         element: Element,
@@ -319,10 +315,6 @@
         self.data[split][str(image_path)] = text
         self.charset = self.charset.union(set(text))
-        # Language model should be built using only text from the training set
-        if split == "train":
-            self.language_corpus.append(self.format_text_language_model(text))
     def process_parent(
         self,
         pbar,
@@ -361,6 +353,9 @@ class ArkindexExtractor:
"""
Convert charset to a LM-compatible charset. Ensure that special LM tokens do not appear in the charset.
"""
logger.info("Preparing language resources")
# Build LM tokens
for token in sorted(list(self.charset)):
assert (
token not in self.mapping.encode.values()
@@ -369,14 +364,27 @@
                 self.mapping.encode[token]
             ) if token in self.mapping.encode else self.language_tokens.append(token)
         # Add the special blank token
         self.language_tokens.append(self.mapping.ctc.encoded)
-        # Build lexicon
-        assert all(
-            [len(token) == 1 for token in self.language_lexicon]
-        ), "Tokens should be single characters."
-        self.language_lexicon = [f"{token} {token}" for token in self.language_tokens]
+        # Build LM corpus
+        train_corpus = [text for text in self.data["train"].values()]
+        tokenizer = Tokenizer(train_corpus, outdir=self.output / "language_model", mapping=self.mapping, tokens=self.tokens)
+        tokenizer.train_subword_tokenizer()
+        self.language_corpus["characters"] = [tokenizer.char_tokenize(doc) for doc in train_corpus]
+        self.language_corpus["words"] = [tokenizer.word_tokenize(doc) for doc in train_corpus]
+        self.language_corpus["subwords"] = [tokenizer.subword_tokenize(doc) for doc in train_corpus]
+        # Build vocabulary
+        word_vocabulary = set([word for doc in self.language_corpus["words"] for word in doc.split(" ")])
+        subword_vocabulary = set([subword for doc in self.language_corpus["subwords"] for subword in doc.split(" ")])
+        # Build LM lexicon
+        self.language_lexicon["characters"] = [f"{token} {tokenizer.char_tokenize(token)}" for token in self.language_tokens]
+        self.language_lexicon["words"] = [f"{word} {tokenizer.char_tokenize(word)}" for word in word_vocabulary]
+        self.language_lexicon["subwords"] = [f"{subword} {tokenizer.char_tokenize(subword)}" for subword in subword_vocabulary]
     def export(self):
         (self.output / "labels.json").write_text(
@@ -386,15 +394,16 @@
                 indent=4,
             )
         )
-        (self.output / "language_model" / "corpus.txt").write_text(
-            "\n".join(self.language_corpus)
-        )
+        for level in ["characters", "words", "subwords"]:
+            (self.output / "language_model" / f"corpus_{level}.txt").write_text(
+                "\n".join(self.language_corpus[level])
+            )
+            (self.output / "language_model" / f"lexicon_{level}.txt").write_text(
+                "\n".join(self.language_lexicon[level])
+            )
         (self.output / "language_model" / "tokens.txt").write_text(
             "\n".join(self.language_tokens)
         )
-        (self.output / "language_model" / "lexicon.txt").write_text(
-            "\n".join(self.language_lexicon)
-        )
         (self.output / "charset.pkl").write_bytes(
             pickle.dumps(sorted(list(self.charset)))
         )