
Add Language Model Decoder

Merged Solene Tarride requested to merge lm-decoder into main
All threads resolved!
3 files  +76  −9
@@ -33,9 +33,10 @@ from dan.datasets.extract.utils import (
     download_image,
     get_bbox,
     insert_token,
-    remove_spaces,
+    normalize_linebreaks,
+    normalize_spaces,
 )
-from dan.utils import EntityType, parse_tokens
+from dan.utils import LM_MAPPING, EntityType, parse_tokens
 from line_image_extractor.extractor import extract
 from line_image_extractor.image_utils import (
     BoundingBox,
@@ -93,6 +94,9 @@ class ArkindexExtractor:
         self.data: Dict = defaultdict(dict)
         self.charset = set()
+        self.language_corpus = []
+        self.language_tokens = []
+        self.language_lexicon = []
 
         # Image download tasks to process
         self.tasks: List[Dict[str, str]] = []
@@ -254,7 +258,8 @@ class ArkindexExtractor:
     def format_text(self, text: str, charset: Optional[set] = None):
         if not self.keep_spaces:
-            text = remove_spaces(text)
+            text = normalize_spaces(text)
+            text = normalize_linebreaks(text)
 
         # Replace unknown characters by the unknown token
         if charset is not None:
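
Note on the replaced helper: remove_spaces stripped whitespace outright, while the two new helpers only normalize it. A minimal sketch of the assumed behavior (the real implementations live in dan.datasets.extract.utils and may differ in regex details):

    import re

    def normalize_spaces(text: str) -> str:
        # Assumed behavior: collapse runs of spaces/tabs into a single space
        return re.sub(r"[ \t]+", " ", text.strip())

    def normalize_linebreaks(text: str) -> str:
        # Assumed behavior: collapse \r and consecutive linebreaks into one \n
        return re.sub(r"[\r\n]+", "\n", text.strip())

Keeping single separators instead of deleting them matters here, since spaces become explicit tokens in the language-model corpus below.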
@@ -265,9 +270,19 @@ class ArkindexExtractor:
                     for unknown_char in unknown_charset
                 }
             )
         return text.strip()
 
+    def format_text_language_model(self, text: str):
+        """
+        Format text for the language model. Return the text tokenized at character level.
+        """
+        return " ".join(
+            [
+                LM_MAPPING[token] if token in LM_MAPPING else token
+                for token in list(text.strip())
+            ]
+        )
+
     def process_element(
         self,
         element: Element,
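
To see what format_text_language_model produces, here is a self-contained sketch with a hypothetical stand-in for dan.utils.LM_MAPPING (the actual symbols in the mapping are assumptions):

    # Hypothetical subset of LM_MAPPING: remap characters that would
    # collide with the space-delimited corpus format
    LM_MAPPING = {" ": "⎵", "\n": "↵"}

    def format_text_language_model(text: str) -> str:
        return " ".join(
            LM_MAPPING[token] if token in LM_MAPPING else token
            for token in list(text.strip())
        )

    print(format_text_language_model("Dear sir"))
    # D e a r ⎵ s i r

Each character becomes one space-separated token, and literal spaces are remapped so they survive as tokens of their own.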
@@ -305,6 +320,8 @@ class ArkindexExtractor:
             self.data[split][str(image_path)] = text
             self.charset = self.charset.union(set(text))
+            if split == "train":
+                self.language_corpus.append(self.format_text_language_model(text))
 
     def process_parent(
         self,
@@ -340,6 +357,27 @@ class ArkindexExtractor:
             except ProcessingError as e:
                 logger.warning(f"Skipping {element.id}: {str(e)}")
 
+    def format_lm_files(self) -> None:
+        """
+        Convert the charset to an LM-compatible token set. Ensure that special LM tokens do not appear in the charset.
+        """
+        for token in sorted(list(self.charset)):
+            assert (
+                token not in LM_MAPPING.values()
+            ), f"Special token {token} is reserved for language modeling."
+            self.language_tokens.append(
+                LM_MAPPING[token] if token in LM_MAPPING else token
+            )
+
+        # Add the special blank token
+        self.language_tokens.append(LM_MAPPING["<ctc>"])
+
+        # Build lexicon
+        assert all(
+            [len(token) == 1 for token in self.language_tokens]
+        ), "Tokens should be single characters."
+        self.language_lexicon = [f"{token} {token}" for token in self.language_tokens]
+
     def export(self):
         (self.output / "labels.json").write_text(
             json.dumps(
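
A toy walkthrough of what format_lm_files builds for a three-character charset (the ⎵ and ◌ symbols are assumed LM_MAPPING values, not confirmed by this diff):

    LM_MAPPING = {" ": "⎵", "<ctc>": "◌"}  # hypothetical values
    charset = sorted({" ", "a", "b"})

    language_tokens = [LM_MAPPING.get(c, c) for c in charset] + [LM_MAPPING["<ctc>"]]
    # ["⎵", "a", "b", "◌"]

    language_lexicon = [f"{t} {t}" for t in language_tokens]
    # ["⎵ ⎵", "a a", "b b", "◌ ◌"]

Each lexicon line maps a word to its spelling; with character-level tokens the two coincide, which is the lexicon format that common CTC beam-search decoders expect.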
@@ -348,6 +386,15 @@ class ArkindexExtractor:
                 indent=4,
             )
         )
+        (self.output / "language_corpus.txt").write_text(
+            "\n".join(self.language_corpus)
+        )
+        (self.output / "language_tokens.txt").write_text(
+            "\n".join(self.language_tokens)
+        )
+        (self.output / "language_lexicon.txt").write_text(
+            "\n".join(self.language_lexicon)
+        )
         (self.output / "charset.pkl").write_bytes(
             pickle.dumps(sorted(list(self.charset)))
         )
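
After a run, the output directory thus gains three plain-text files next to labels.json and charset.pkl. Their expected shape, assuming the token symbols sketched above:

    language_corpus.txt    one tokenized training transcription per line, e.g. "D e a r ⎵ s i r"
    language_tokens.txt    one LM token per line, ending with the blank token
    language_lexicon.txt   one "token token" pair per line, e.g. "a a"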
@@ -408,6 +455,7 @@ class ArkindexExtractor:
             pbar.refresh()
 
         self.download_images()
+        self.format_lm_files()
         self.export()