Commit bcfaa598 authored by Solene Tarride, committed by Solene Tarride

Generate LM files during data extraction

parent 6e78a28d
@@ -27,9 +27,10 @@ from dan.datasets.extract.exceptions import (
 from dan.datasets.extract.utils import (
     download_image,
     insert_token,
-    remove_spaces,
+    normalize_linebreaks,
+    normalize_spaces,
 )
-from dan.utils import EntityType, parse_tokens
+from dan.utils import LM_MAPPING, EntityType, parse_tokens
 from line_image_extractor.extractor import extract, read_img, save_img
 from line_image_extractor.image_utils import Extraction, polygon_to_bbox, resize
@@ -83,6 +84,9 @@ class ArkindexExtractor:
         self.data: Dict = defaultdict(dict)
         self.charset = set()
+        self.language_corpus = []
+        self.language_tokens = []
+        self.language_lexicon = []
 
     def find_image_in_cache(self, image_id: str) -> Path:
         """Images are cached to avoid downloading them twice. They are stored under a specific name,
@@ -223,10 +227,25 @@
         save_img(path=destination, img=image)
 
     def format_text(self, text: str):
+        """
+        Strip text and remove duplicate spaces and linebreaks if needed.
+        """
         if not self.keep_spaces:
-            text = remove_spaces(text)
+            text = normalize_spaces(text)
+            text = normalize_linebreaks(text)
         return text.strip()
 
+    def format_text_language_model(self, text: str):
+        """
+        Format text for the language model. Return the text tokenized at character level.
+        """
+        return " ".join(
+            [
+                LM_MAPPING[token] if token in LM_MAPPING else token
+                for token in list(text.strip())
+            ]
+        )
+
     def process_element(
         self,
         element: Element,
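In isolation, format_text_language_model just space-separates the characters of a transcription, remapping the characters that are reserved by the LM tooling. A minimal runnable sketch, assuming the excerpt of the LM_MAPPING added to dan/utils.py further down:

    LM_MAPPING = {" ": "⎵", "\n": "↵"}  # excerpt; see dan/utils.py below

    def format_text_language_model(text: str) -> str:
        # One token per character, with reserved characters remapped.
        return " ".join(
            LM_MAPPING[token] if token in LM_MAPPING else token
            for token in text.strip()
        )

    print(format_text_language_model("le 15 mai"))  # l e ⎵ 1 5 ⎵ m a i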
@@ -243,8 +262,11 @@
         ).with_suffix(self.image_extension)
         self.get_image(element, image_path)
 
-        self.data[split][str(image_path)] = self.format_text(text)
-        self.charset = self.charset.union(set(text))
+        clean_text = self.format_text(text)
+        self.data[split][str(image_path)] = clean_text
+        self.charset = self.charset.union(set(clean_text))
+        if split == "train":
+            self.language_corpus.append(self.format_text_language_model(clean_text))
 
     def process_parent(
         self,
@@ -280,6 +302,27 @@
             except ProcessingError as e:
                 logger.warning(f"Skipping {element.id}: {str(e)}")
 
+    def format_lm_files(self) -> None:
+        """
+        Convert the charset into an LM-compatible token list. Ensure that special LM tokens do not appear in the charset.
+        """
+        for token in sorted(list(self.charset)):
+            assert (
+                token not in LM_MAPPING.values()
+            ), f"Special token {token} is reserved for language modeling."
+            self.language_tokens.append(LM_MAPPING.get(token, token))
+
+        # Add the special blank token
+        self.language_tokens.append(LM_MAPPING["<ctc>"])
+
+        # Build lexicon
+        assert all(
+            [len(token) == 1 for token in self.language_tokens]
+        ), "Tokens should be single characters."
+        self.language_lexicon = [f"{token} {token}" for token in self.language_tokens]
+
     def export(self):
         (self.output / "labels.json").write_text(
             json.dumps(
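To make the shape of these artifacts concrete, here is the same logic rerun on a toy charset (a standalone sketch; the variable names are local to it and the glyphs assume the LM_MAPPING shown further down):

    LM_MAPPING = {" ": "⎵", "\n": "↵", "<ctc>": "◌", "<unk>": "⁇"}
    charset = {" ", "a", "b"}

    language_tokens = [LM_MAPPING.get(c, c) for c in sorted(charset)]
    language_tokens.append(LM_MAPPING["<ctc>"])  # the CTC blank token comes last

    # One lexicon entry per token; at character level the "word" and its
    # token sequence coincide.
    language_lexicon = [f"{token} {token}" for token in language_tokens]

    print(language_tokens)   # ['⎵', 'a', 'b', '◌']
    print(language_lexicon)  # ['⎵ ⎵', 'a a', 'b b', '◌ ◌']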
@@ -288,6 +331,15 @@
                 indent=4,
             )
         )
+        (self.output / "language_corpus.txt").write_text(
+            "\n".join(self.language_corpus)
+        )
+        (self.output / "language_tokens.txt").write_text(
+            "\n".join(self.language_tokens)
+        )
+        (self.output / "language_lexicon.txt").write_text(
+            "\n".join(self.language_lexicon)
+        )
         (self.output / "charset.pkl").write_bytes(
             pickle.dumps(sorted(list(self.charset)))
         )
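The three exported text files match the tokens/lexicon conventions used by lexicon-based CTC beam-search decoders. One possible consumer, shown purely as an assumption since this commit does not wire up decoding, is torchaudio's ctc_decoder, combined with a hypothetical n-gram model trained on language_corpus.txt (e.g. with KenLM):

    from torchaudio.models.decoder import ctc_decoder

    decoder = ctc_decoder(
        lexicon="language_lexicon.txt",  # exported above
        tokens="language_tokens.txt",    # exported above
        lm="model.arpa",                 # hypothetical KenLM model built from language_corpus.txt
        blank_token="◌",                 # LM_MAPPING["<ctc>"]
        sil_token="⎵",                   # LM_MAPPING[" "]
        unk_word="⁇",                    # LM_MAPPING["<unk>"]
    )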
@@ -312,6 +364,7 @@
                 # Progress bar updates
                 pbar.update()
                 pbar.refresh()
+        self.format_lm_files()
         self.export()
@@ -20,7 +20,8 @@ logger = logging.getLogger(__name__)
 DOWNLOAD_TIMEOUT = (30, 60)
 
 # replace \t with regular space and consecutive spaces
-TRIM_REGEX = re.compile(r"\t?(?: +)")
+TRIM_SPACE_REGEX = re.compile(r"[\t ]+")
+TRIM_RETURN_REGEX = re.compile(r"[\r\n]+")
 
 
 def _retry_log(retry_state, *args, **kwargs):
@@ -76,7 +77,17 @@ def insert_token(text: str, entity_type: EntityType, offset: int, length: int) -> str:
     )
 
 
-def remove_spaces(text: str) -> str:
-    # remove begin/ending spaces
-    # replace \t with regular space and consecutive spaces
-    return TRIM_REGEX.sub(" ", text.strip())
+def normalize_linebreaks(text: str) -> str:
+    """
+    Strip leading/trailing linebreaks, and replace \r and consecutive linebreaks with a single \n.
+    """
+    return TRIM_RETURN_REGEX.sub("\n", text.strip())
+
+
+def normalize_spaces(text: str) -> str:
+    """
+    Strip leading/trailing spaces, and replace \t and consecutive spaces with a single space.
+    """
+    return TRIM_SPACE_REGEX.sub(" ", text.strip())
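A quick doctest-style check of the two helpers above:

    >>> normalize_spaces("dog\t cat  bird")
    'dog cat bird'
    >>> normalize_linebreaks("dog\r\n\ncat\rbird")
    'dog\ncat\nbird'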
@@ -16,6 +16,14 @@ class MLflowNotInstalled(Exception):
     """
 
 
+LM_MAPPING = {
+    " ": "⎵",
+    "\n": "↵",
+    "<ctc>": "◌",
+    "<unk>": "⁇",
+}
+
+
 class EntityType(NamedTuple):
     start: str
     end: str = ""
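Because every mapped value is a single character, the mapping is trivially invertible. A hypothetical helper, not part of this commit, for turning LM-space output back into plain text:

    REVERSE_LM_MAPPING = {value: key for key, value in LM_MAPPING.items()}

    def decode_lm_text(tokens: str) -> str:
        # "l e ⎵ 1 5" -> "le 15"
        return "".join(REVERSE_LM_MAPPING.get(char, char) for char in tokens.split(" "))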