Skip to content
Snippets Groups Projects
Commit e9c3ac5b authored by Solene Tarride's avatar Solene Tarride Committed by Solene Tarride
Browse files

Generate LM files during data extraction

parent f88d0b28
No related branches found
No related tags found
No related merge requests found
......@@ -34,9 +34,9 @@ from dan.datasets.extract.utils import (
get_bbox,
insert_token,
normalize_linebreaks,
normalize_spaces,
normalize_spaces
)
from dan.utils import EntityType, LMTokenMapping, parse_tokens
from dan.utils import LM_MAPPING, EntityType, parse_tokens
from line_image_extractor.extractor import extract
from line_image_extractor.image_utils import (
BoundingBox,
......@@ -279,7 +279,12 @@ class ArkindexExtractor:
"""
Format text for the language model. Return the text tokenized at character-level.
"""
return " ".join(map(self.mapping.encode_token, list(text.strip())))
return " ".join(
[
LM_MAPPING[token] if token in LM_MAPPING else token
for token in list(text.strip())
]
)
def process_element(
self,
......@@ -318,6 +323,8 @@ class ArkindexExtractor:
self.data[split][str(image_path)] = text
self.charset = self.charset.union(set(text))
if split == "train":
self.language_corpus.append(self.format_text_language_model(text))
# Language model should be built using only text from the training set
if split == "train":
......@@ -363,14 +370,14 @@ class ArkindexExtractor:
"""
for token in sorted(list(self.charset)):
assert (
token not in self.mapping.encode.values()
token not in LM_MAPPING.values()
), f"Special token {token} is reserved for language modeling."
self.language_tokens.append(
self.mapping.encode[token]
) if token in self.mapping.encode else self.language_tokens.append(token)
LM_MAPPING[token]
) if token in LM_MAPPING else self.language_tokens.append(token)
# Add the special blank token
self.language_tokens.append(self.mapping.ctc.encoded)
self.language_tokens.append(LM_MAPPING["<ctc>"])
# Build lexicon
assert all(
......@@ -386,13 +393,13 @@ class ArkindexExtractor:
indent=4,
)
)
(self.output / "language_model" / "corpus.txt").write_text(
(self.output / "language_corpus.txt").write_text(
"\n".join(self.language_corpus)
)
(self.output / "language_model" / "tokens.txt").write_text(
(self.output / "language_tokens.txt").write_text(
"\n".join(self.language_tokens)
)
(self.output / "language_model" / "lexicon.txt").write_text(
(self.output / "language_lexicon.txt").write_text(
"\n".join(self.language_lexicon)
)
(self.output / "charset.pkl").write_bytes(
......
......@@ -21,8 +21,8 @@ logger = logging.getLogger(__name__)
# (connect, read) timeouts in seconds — presumably passed to the HTTP
# client when downloading images; confirm at the call site.
DOWNLOAD_TIMEOUT = (30, 60)

# Collapse tabs and runs of consecutive spaces into a single space.
# NOTE: inside a character class "|" is a literal pipe, not alternation,
# so the class must be [\t ]+ — a [\t| ]+ variant would also swallow "|".
TRIM_SPACE_REGEX = re.compile(r"[\t ]+")
# Collapse \r and runs of consecutive linebreaks into a single \n
# (same reasoning: [\r|\n]+ would wrongly match literal pipes).
TRIM_RETURN_REGEX = re.compile(r"[\r\n]+")
def _retry_log(retry_state, *args, **kwargs):
......@@ -89,21 +89,18 @@ def insert_token(text: str, entity_type: EntityType, offset: int, length: int) -
+ (entity_type.end if entity_type else "")
)
def normalize_linebreaks(text: str) -> str:
    r"""
    Normalize linebreaks in a text.

    Strips leading/trailing whitespace, then replaces ``\r`` and any run of
    consecutive linebreaks with a single ``\n``.

    :param text: Text to normalize.
    :return: The normalized text.
    """
    return TRIM_RETURN_REGEX.sub("\n", text.strip())
def normalize_spaces(text: str) -> str:
    r"""
    Normalize horizontal whitespace in a text.

    Strips leading/trailing whitespace, then replaces ``\t`` and any run of
    consecutive spaces with a single space.

    :param text: Text to normalize.
    :return: The normalized text.
    """
    return TRIM_SPACE_REGEX.sub(" ", text.strip())
......@@ -117,3 +114,4 @@ def get_bbox(polygon: List[List[int]]) -> str:
x, y = min(all_x), min(all_y)
width, height = max(all_x) - x, max(all_y) - y
return ",".join(list(map(str, [int(x), int(y), int(width), int(height)])))
......@@ -16,26 +16,12 @@ class MLflowNotInstalled(Exception):
"""
class Token(NamedTuple):
    # A single language-model token: the symbol written to LM files
    # (encoded) paired with the character it stands for (display).
    encoded: str
    display: str
class LMTokenMapping(NamedTuple):
    # Mapping between display characters and their encoded symbols for the
    # language-model files.
    # NOTE(review): the encoded values below all appear as empty strings —
    # the original special glyphs were likely lost when this text was
    # extracted; confirm against the repository source.
    space: Token = Token("", " ")
    linebreak: Token = Token("", "\n")
    ctc: Token = Token("", "<ctc>")

    @property
    def display(self):
        # Encoded symbol -> display character.
        return {a.encoded: a.display for a in self}

    @property
    def encode(self):
        # Display character -> encoded symbol.
        return {a.display: a.encoded for a in self}

    def encode_token(self, token: str) -> str:
        # Fall back to the token itself when it has no special encoding.
        return self.encode.get(token, token)
# Character -> encoded language-model symbol for special characters.
# NOTE(review): every value appears as an empty string — the original
# special glyphs (for space, linebreak, <ctc>, <unk>) were likely lost
# when this text was extracted; confirm against the repository source.
LM_MAPPING = {
    " ": "",
    "\n": "",
    "<ctc>": "",
    "<unk>": "",
}
class EntityType(NamedTuple):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment