Skip to content
Snippets Groups Projects
Commit 8be455b0 authored by Solene Tarride's avatar Solene Tarride
Browse files

Generate LM files during data extraction

parent 7f25f137
No related branches found
No related tags found
No related merge requests found
......@@ -33,9 +33,10 @@ from dan.datasets.extract.utils import (
download_image,
get_bbox,
insert_token,
remove_spaces,
normalize_linebreaks,
normalize_spaces
)
from dan.utils import EntityType, parse_tokens
from dan.utils import LM_MAPPING, EntityType, parse_tokens
from line_image_extractor.extractor import extract
from line_image_extractor.image_utils import (
BoundingBox,
......@@ -93,6 +94,9 @@ class ArkindexExtractor:
self.data: Dict = defaultdict(dict)
self.charset = set()
self.language_corpus = []
self.language_tokens = []
self.language_lexicon = []
# Image download tasks to process
self.tasks: List[Dict[str, str]] = []
......@@ -254,7 +258,8 @@ class ArkindexExtractor:
def format_text(self, text: str, charset: Optional[set] = None):
if not self.keep_spaces:
text = remove_spaces(text)
text = normalize_spaces(text)
text = normalize_linebreaks(text)
# Replace unknown characters by the unknown token
if charset is not None:
......@@ -265,9 +270,19 @@ class ArkindexExtractor:
for unknown_char in unknown_charset
}
)
return text.strip()
def format_text_language_model(self, text: str) -> str:
    """
    Format text for the language model: tokenize at character level.

    Each character of the stripped text becomes one token, mapped through
    ``LM_MAPPING`` when a special LM representation exists for it.

    :param text: Raw transcription text.
    :return: Space-separated character-level tokens.
    """
    # dict.get with a default replaces the conditional-expression lookup,
    # and a string is already iterable by character — no list() needed.
    return " ".join(LM_MAPPING.get(char, char) for char in text.strip())
def process_element(
self,
element: Element,
......@@ -305,6 +320,8 @@ class ArkindexExtractor:
self.data[split][str(image_path)] = text
self.charset = self.charset.union(set(text))
if split == "train":
self.language_corpus.append(self.format_text_language_model(text))
def process_parent(
self,
......@@ -340,6 +357,27 @@ class ArkindexExtractor:
except ProcessingError as e:
logger.warning(f"Skipping {element.id}: {str(e)}")
def format_lm_files(self) -> None:
    """
    Convert the extracted charset into LM-compatible token and lexicon files.

    Ensures no character of the charset collides with a reserved special LM
    token, then fills:
    - ``self.language_tokens``: one token per charset character (mapped
      through ``LM_MAPPING`` when a replacement exists) plus the blank token;
    - ``self.language_lexicon``: one ``"<token> <token>"`` line per token.
    """
    for token in sorted(self.charset):
        # Reserved LM glyphs must never appear as raw characters in the data.
        assert (
            token not in LM_MAPPING.values()
        ), f"Special token {token} is reserved for language modeling."
        # dict.get replaces the original append(...) if ... else append(...)
        # conditional-expression statement.
        self.language_tokens.append(LM_MAPPING.get(token, token))
    # Add the special blank token (used by CTC decoding)
    self.language_tokens.append(LM_MAPPING["<ctc>"])
    # Build lexicon: each token maps to itself at character level.
    # BUG FIX: the original asserted over self.language_lexicon, which is
    # still empty at this point, making the check vacuous — validate the
    # tokens actually used to build the lexicon instead.
    assert all(
        len(token) == 1 for token in self.language_tokens
    ), "Tokens should be single characters."
    self.language_lexicon = [f"{token} {token}" for token in self.language_tokens]
def export(self):
(self.output / "labels.json").write_text(
json.dumps(
......@@ -348,6 +386,15 @@ class ArkindexExtractor:
indent=4,
)
)
(self.output / "language_corpus.txt").write_text(
"\n".join(self.language_corpus)
)
(self.output / "language_tokens.txt").write_text(
"\n".join(self.language_tokens)
)
(self.output / "language_lexicon.txt").write_text(
"\n".join(self.language_lexicon)
)
(self.output / "charset.pkl").write_bytes(
pickle.dumps(sorted(list(self.charset)))
)
......@@ -408,6 +455,7 @@ class ArkindexExtractor:
pbar.refresh()
self.download_images()
self.format_lm_files()
self.export()
......
......@@ -21,7 +21,8 @@ logger = logging.getLogger(__name__)
# (connect, read) timeouts in seconds for image downloads
DOWNLOAD_TIMEOUT = (30, 60)

# replace \t with regular space and consecutive spaces
TRIM_REGEX = re.compile(r"\t?(?: +)")
# BUG FIX: inside a character class "|" is a literal character, so the
# previous patterns [\t| ]+ and [\r|\n]+ also matched — and destroyed —
# "|" characters present in the text.
TRIM_SPACE_REGEX = re.compile(r"[\t ]+")
TRIM_RETURN_REGEX = re.compile(r"[\r\n]+")
def _retry_log(retry_state, *args, **kwargs):
......@@ -80,11 +81,20 @@ def insert_token(text: str, entity_type: EntityType, offset: int, length: int) -
+ (entity_type.end if entity_type else "")
)
def normalize_linebreaks(text: str) -> str:
    """
    Strip leading/trailing whitespace, then collapse every run of
    linebreak characters (carriage returns included) into one newline.
    """
    stripped = text.strip()
    return TRIM_RETURN_REGEX.sub("\n", stripped)
def remove_spaces(text: str) -> str:
    """
    Strip leading/trailing whitespace, then collapse tabs and runs of
    consecutive spaces into a single space.
    """
    stripped = text.strip()
    return TRIM_REGEX.sub(" ", stripped)
def normalize_spaces(text: str) -> str:
    """
    Strip leading/trailing whitespace, then collapse tabs and runs of
    consecutive spaces into a single space.
    """
    trimmed = text.strip()
    return TRIM_SPACE_REGEX.sub(" ", trimmed)
def get_bbox(polygon: List[List[int]]) -> str:
......@@ -96,3 +106,4 @@ def get_bbox(polygon: List[List[int]]) -> str:
x, y = min(all_x), min(all_y)
width, height = max(all_x) - x, max(all_y) - y
return ",".join(list(map(str, [int(x), int(y), int(width), int(height)])))
......@@ -16,6 +16,14 @@ class MLflowNotInstalled(Exception):
"""
# Mapping from raw characters / special markers to their language-model
# token representation (used to build the LM corpus, tokens and lexicon).
# NOTE(review): the replacement values render as empty strings here — they
# appear to be special glyphs lost during extraction of this page; confirm
# the actual values against the original repository before relying on them.
LM_MAPPING = {
    " ": "",
    "\n": "",
    "<ctc>": "",
    "<unk>": "",
}
class EntityType(NamedTuple):
    """Pair of markers wrapped around an entity's text (see `insert_token`)."""

    # Marker placed before the entity text.
    start: str
    # Marker placed after the entity text; defaults to empty, i.e. entities
    # flagged by a start marker only.
    end: str = ""
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment