Commit e9c3ac5b authored by Solene Tarride, committed by Solene Tarride

Generate LM files during data extraction

parent f88d0b28

dan/datasets/extract/extract.py

@@ -34,9 +34,9 @@ from dan.datasets.extract.utils import (
     get_bbox,
     insert_token,
     normalize_linebreaks,
-    normalize_spaces,
+    normalize_spaces
 )
-from dan.utils import EntityType, LMTokenMapping, parse_tokens
+from dan.utils import LM_MAPPING, EntityType, parse_tokens
 from line_image_extractor.extractor import extract
 from line_image_extractor.image_utils import (
     BoundingBox,

@@ -279,7 +279,12 @@ class ArkindexExtractor:
         """
         Format text for the language model. Return the text tokenized at character-level.
         """
-        return " ".join(map(self.mapping.encode_token, list(text.strip())))
+        return " ".join(
+            [
+                LM_MAPPING[token] if token in LM_MAPPING else token
+                for token in list(text.strip())
+            ]
+        )
 
     def process_element(
         self,
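
For reference, a minimal sketch of what the new-side formatting produces, assuming the whitespace glyphs defined in dan/utils.py further down in this diff:

    LM_MAPPING = {" ": "⎵", "\n": "↵", "<ctc>": "◌", "<unk>": "⁇"}

    def format_text_language_model(text: str) -> str:
        # Tokenize at character level; encode whitespace with visible glyphs
        return " ".join(
            [
                LM_MAPPING[token] if token in LM_MAPPING else token
                for token in list(text.strip())
            ]
        )

    print(format_text_language_model("le monde"))  # l e ⎵ m o n d e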

@@ -318,6 +323,8 @@ class ArkindexExtractor:
         self.data[split][str(image_path)] = text
         self.charset = self.charset.union(set(text))
+        if split == "train":
+            self.language_corpus.append(self.format_text_language_model(text))
 
         # Language model should be built using only text from the training set
         if split == "train":
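
The added guard keeps the corpus train-only; a self-contained sketch of the accumulation, with a hypothetical on_element helper and the glyph mapping simplified away:

    language_corpus = []

    def format_text_language_model(text: str) -> str:
        # Simplified stand-in: character-level split without glyph mapping
        return " ".join(list(text.strip()))

    def on_element(split: str, text: str) -> None:
        # Only text from the training split feeds the language-model corpus
        if split == "train":
            language_corpus.append(format_text_language_model(text))

    on_element("train", "ab")
    on_element("val", "cd")
    print(language_corpus)  # ['a b']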

@@ -363,14 +370,14 @@ class ArkindexExtractor:
         """
         for token in sorted(list(self.charset)):
             assert (
-                token not in self.mapping.encode.values()
+                token not in LM_MAPPING.values()
             ), f"Special token {token} is reserved for language modeling."
             self.language_tokens.append(
-                self.mapping.encode[token]
-            ) if token in self.mapping.encode else self.language_tokens.append(token)
+                LM_MAPPING[token]
+            ) if token in LM_MAPPING else self.language_tokens.append(token)
 
         # Add the special blank token
-        self.language_tokens.append(self.mapping.ctc.encoded)
+        self.language_tokens.append(LM_MAPPING["<ctc>"])
 
         # Build lexicon
         assert all(
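
Run in isolation on a toy charset, the new-side token-list construction behaves as follows (the conditional-expression form is rewritten with dict.get for readability; glyphs assumed from dan/utils.py):

    LM_MAPPING = {" ": "⎵", "\n": "↵", "<ctc>": "◌", "<unk>": "⁇"}
    charset = {" ", "a", "b"}

    language_tokens = []
    for token in sorted(list(charset)):
        # Plain characters must not collide with the reserved glyphs
        assert token not in LM_MAPPING.values()
        # Whitespace is stored as its glyph, ordinary characters as-is
        language_tokens.append(LM_MAPPING.get(token, token))
    # Append the special CTC blank token last
    language_tokens.append(LM_MAPPING["<ctc>"])
    print(language_tokens)  # ['⎵', 'a', 'b', '◌']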

@@ -386,13 +393,13 @@ class ArkindexExtractor:
                 indent=4,
             )
         )
-        (self.output / "language_model" / "corpus.txt").write_text(
+        (self.output / "language_corpus.txt").write_text(
             "\n".join(self.language_corpus)
         )
-        (self.output / "language_model" / "tokens.txt").write_text(
+        (self.output / "language_tokens.txt").write_text(
             "\n".join(self.language_tokens)
         )
-        (self.output / "language_model" / "lexicon.txt").write_text(
+        (self.output / "language_lexicon.txt").write_text(
             "\n".join(self.language_lexicon)
        )
         (self.output / "charset.pkl").write_bytes(
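
A sketch of the three artifacts this block writes, for a one-line toy corpus; the lexicon layout (each token mapping to itself for character-level decoding) is an assumption, since its construction is truncated in this view:

    from pathlib import Path

    output = Path("output")
    output.mkdir(exist_ok=True)

    language_corpus = ["l e ⎵ m o n d e"]  # one encoded line per training text
    language_tokens = ["d", "e", "l", "m", "n", "o", "⎵", "◌"]
    # Assumed character-level lexicon: every token maps to itself
    language_lexicon = [f"{token} {token}" for token in language_tokens]

    (output / "language_corpus.txt").write_text("\n".join(language_corpus))
    (output / "language_tokens.txt").write_text("\n".join(language_tokens))
    (output / "language_lexicon.txt").write_text("\n".join(language_lexicon))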

dan/datasets/extract/utils.py

@@ -21,8 +21,8 @@ logger = logging.getLogger(__name__)
 DOWNLOAD_TIMEOUT = (30, 60)
 
 # replace \t with regular space and consecutive spaces
-TRIM_SPACE_REGEX = re.compile(r"[\t ]+")
-TRIM_RETURN_REGEX = re.compile(r"[\r\n]+")
+TRIM_SPACE_REGEX = re.compile(r"[\t| ]+")
+TRIM_RETURN_REGEX = re.compile(r"[\r|\n]+")
 
 
 def _retry_log(retry_state, *args, **kwargs):
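
Note that inside a character class "|" is a literal pipe, not alternation, so the new-side patterns also collapse pipe characters; a quick check:

    import re

    # New-side pattern: the "|" is matched literally
    print(re.sub(r"[\t| ]+", " ", "a|\tb"))  # 'a b'  (the pipe is swallowed)
    # Old-side pattern only matches tabs and spaces
    print(re.sub(r"[\t ]+", " ", "a|\tb"))   # 'a| b'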

@@ -89,21 +89,18 @@ def insert_token(text: str, entity_type: EntityType, offset: int, length: int) -> str:
         + (entity_type.end if entity_type else "")
     )
 
 
 def normalize_linebreaks(text: str) -> str:
     """
-    Remove begin/ending linebreaks.
-    Replace \r with regular linebreak and consecutive linebreaks.
-    :param text: Text to normalize.
+    Remove begin/ending linebreaks
+    Replace \r with regular linebreak and consecutive linebreaks
     """
     return TRIM_RETURN_REGEX.sub("\n", text.strip())
 
 
 def normalize_spaces(text: str) -> str:
     """
-    Remove begin/ending spaces.
-    Replace \t with regular space and consecutive spaces.
-    :param text: Text to normalize.
+    Remove begin/ending spaces
+    Replace \t with regular space and consecutive spaces
     """
     return TRIM_SPACE_REGEX.sub(" ", text.strip())
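
Both helpers are pure string transforms; a minimal sketch of their expected behavior, using the old-side patterns:

    import re

    TRIM_SPACE_REGEX = re.compile(r"[\t ]+")
    TRIM_RETURN_REGEX = re.compile(r"[\r\n]+")

    def normalize_spaces(text: str) -> str:
        # Strip, then collapse tabs and runs of spaces into one space
        return TRIM_SPACE_REGEX.sub(" ", text.strip())

    def normalize_linebreaks(text: str) -> str:
        # Strip, then collapse \r and runs of linebreaks into one \n
        return TRIM_RETURN_REGEX.sub("\n", text.strip())

    assert normalize_spaces("  a\t\t b  ") == "a b"
    assert normalize_linebreaks("a\r\n\nb\n") == "a\nb"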

@@ -117,3 +114,4 @@ def get_bbox(polygon: List[List[int]]) -> str:
     x, y = min(all_x), min(all_y)
     width, height = max(all_x) - x, max(all_y) - y
     return ",".join(list(map(str, [int(x), int(y), int(width), int(height)])))
+

dan/utils.py

@@ -16,26 +16,12 @@ class MLflowNotInstalled(Exception):
     """
 
 
-class Token(NamedTuple):
-    encoded: str
-    display: str
-
-
-class LMTokenMapping(NamedTuple):
-    space: Token = Token("⎵", " ")
-    linebreak: Token = Token("↵", "\n")
-    ctc: Token = Token("◌", "<ctc>")
-
-    @property
-    def display(self):
-        return {a.encoded: a.display for a in self}
-
-    @property
-    def encode(self):
-        return {a.display: a.encoded for a in self}
-
-    def encode_token(self, token: str) -> str:
-        return self.encode.get(token, token)
+LM_MAPPING = {
+    " ": "⎵",
+    "\n": "↵",
+    "<ctc>": "◌",
+    "<unk>": "⁇",
+}
 
 
 class EntityType(NamedTuple):
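
The removed mapping class can be exercised on its own; a minimal sketch, assuming the glyphs shown above:

    from typing import NamedTuple

    class Token(NamedTuple):
        encoded: str
        display: str

    class LMTokenMapping(NamedTuple):
        space: Token = Token("⎵", " ")
        linebreak: Token = Token("↵", "\n")
        ctc: Token = Token("◌", "<ctc>")

        @property
        def encode(self):
            # display character -> encoded glyph
            return {a.display: a.encoded for a in self}

        def encode_token(self, token: str) -> str:
            # Unknown tokens pass through unchanged
            return self.encode.get(token, token)

    mapping = LMTokenMapping()
    assert mapping.encode_token(" ") == "⎵"
    assert mapping.encode_token("a") == "a"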