From b215870164b1842c4286d8f84c9cc07b28da55da Mon Sep 17 00:00:00 2001 From: Gabriel Bermes Poupeau <gbermesp@teklia.com> Date: Wed, 17 Jul 2024 07:29:24 +0000 Subject: [PATCH] Move the unknown token replacement step to download --- .gitattributes | 1 - dan/datasets/download/__init__.py | 21 ++ dan/datasets/download/images.py | 129 ++++++++++- dan/datasets/download/utils.py | 136 +++++++++++ dan/datasets/extract/__init__.py | 14 +- dan/datasets/extract/arkindex.py | 137 +---------- dan/datasets/extract/exceptions.py | 9 - dan/datasets/extract/utils.py | 134 +---------- docs/usage/datasets/download.md | 43 +++- docs/usage/datasets/extract.md | 4 - tests/__init__.py | 53 ++++- tests/data/extraction/split.json | 18 +- tests/test_download.py | 351 +++++++++++++++++++++++++---- tests/test_extract.py | 344 ++-------------------------- 14 files changed, 727 insertions(+), 667 deletions(-) diff --git a/.gitattributes b/.gitattributes index b19024a2..7a3b5ee0 100644 --- a/.gitattributes +++ b/.gitattributes @@ -4,4 +4,3 @@ tests/data/prediction/language_model.arpa filter=lfs diff=lfs merge=lfs -text docs/assets/example_line.gif filter=lfs diff=lfs merge=lfs -text docs/assets/example_line_polygon.gif filter=lfs diff=lfs merge=lfs -text docs/assets/example_word.gif filter=lfs diff=lfs merge=lfs -text -docs/assets/example.jpg filter=lfs diff=lfs merge=lfs -text diff --git a/dan/datasets/download/__init__.py b/dan/datasets/download/__init__.py index 00c06371..d32b29d2 100644 --- a/dan/datasets/download/__init__.py +++ b/dan/datasets/download/__init__.py @@ -53,4 +53,25 @@ def add_download_parser(subcommands) -> None: help="Images will be saved under this format.", ) + parser.add_argument( + "--unknown-token", + type=str, + default="â‡", + help="Token to use to replace character in the validation/test sets that is not included in the training set.", + ) + + parser.add_argument( + "--subword-vocab-size", + type=int, + help="Size of the vocabulary to train the sentencepiece subword tokenizer needed for language model.", + default=1000, + ) + + parser.add_argument( + "--tokens", + type=pathlib.Path, + help="Mapping between starting tokens and end tokens to extract text with their entities.", + required=False, + ) + parser.set_defaults(func=run) diff --git a/dan/datasets/download/images.py b/dan/datasets/download/images.py index 09d4fbe8..b6b93e92 100644 --- a/dan/datasets/download/images.py +++ b/dan/datasets/download/images.py @@ -5,8 +5,10 @@ import json import logging +import pickle from collections import defaultdict from concurrent.futures import Future, ThreadPoolExecutor +from itertools import chain from pathlib import Path from typing import Dict, List, Tuple @@ -16,7 +18,14 @@ from PIL import Image from tqdm import tqdm from dan.datasets.download.exceptions import ImageDownloadError -from dan.datasets.download.utils import download_image, get_bbox +from dan.datasets.download.utils import ( + Tokenizer, + download_image, + get_bbox, + get_vocabulary, +) +from dan.datasets.extract.arkindex import TRAIN_NAME +from dan.utils import LMTokenMapping, parse_tokens from line_image_extractor.extractor import extract from line_image_extractor.image_utils import ( BoundingBox, @@ -24,6 +33,8 @@ from line_image_extractor.image_utils import ( polygon_to_bbox, ) +LANGUAGE_DIR = "language_model" # Subpath to the language model directory. + IMAGES_DIR = "images" # Subpath to the images directory. 
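# IIIF Image API URL template used to crop each element. `bbox` is expected to be
# the "x,y,width,height" region string built by `get_bbox` from the polygon, and
# `size` the value returned by `get_iiif_size_arg` (either `IIIF_FULL_SIZE` or a
# "{width}," / ",{height}" constraint).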
IIIF_URL = "{image_url}/{bbox}/{size}/0/default.jpg" @@ -44,6 +55,9 @@ class ImageDownloader: max_width: int | None = None, max_height: int | None = None, image_extension: str = "", + unknown_token: str = "â‡", + subword_vocab_size: int = 1000, + tokens: Path | None = None, ) -> None: self.output = output @@ -51,6 +65,16 @@ class ImageDownloader: self.max_height = max_height self.image_extension = image_extension + self.unknown_token = unknown_token + self.tokens = parse_tokens(tokens) if tokens else {} + + self.subword_vocab_size = subword_vocab_size + self.mapping = LMTokenMapping() + + self.language_corpus = defaultdict(list) + self.language_tokens = [] + self.language_lexicon = defaultdict(list) + # Load split file split_file = self.output / "split.json" if self.output else None self.split: Dict = ( @@ -58,11 +82,17 @@ class ImageDownloader: if split_file and split_file.is_file() else {} ) + # Create directories for split_name in self.split: - Path(output, IMAGES_DIR, split_name).mkdir(parents=True, exist_ok=True) + (output / IMAGES_DIR / split_name).mkdir(parents=True, exist_ok=True) self.data: Dict = defaultdict(dict) + self.charset = set( + chain.from_iterable( + split_data["text"] for split_data in self.split[TRAIN_NAME].values() + ) + ) def check_extraction(self, values: dict) -> str | None: # Check dataset_id parameter @@ -83,6 +113,9 @@ class ImageDownloader: if values.get("text") is None: return "Text not found" + if self.unknown_token in values["text"]: + return "Unknown token found in the transcription text" + def get_iiif_size_arg(self, width: int, height: int) -> str: if (self.max_width is None or width <= self.max_width) and ( self.max_height is None or height <= self.max_height @@ -130,6 +163,16 @@ class ImageDownloader: image_path = destination / values["dataset_id"] / filename image_path.parent.mkdir(parents=True, exist_ok=True) + # Replace unknown characters by the unknown token + if split != TRAIN_NAME: + unknown_charset = set(values["text"]) - self.charset + values["text"] = values["text"].translate( + { + ord(unknown_char): self.unknown_token + for unknown_char in unknown_charset + } + ) + # Store a relative path to the label file in case we need to move the data elsewhere self.data[split][str(image_path.relative_to(self.output))] = values[ "text" @@ -230,6 +273,62 @@ class ImageDownloader: logger.error(f"Failed to download {len(failed_downloads)} image(s).") print(*list(map(": ".join, failed_downloads)), sep="\n") + def format_lm_files(self) -> None: + """ + Convert charset to a LM-compatible charset. Ensure that special LM tokens do not appear in the charset. + """ + logger.info("Preparing language resources") + # Add unknown token to charset + self.charset.add(self.unknown_token) + + # Build LM tokens + for token in sorted(list(self.charset)): + assert ( + token not in self.mapping.encode.values() + ), f"Special token {token} is reserved for language modeling." 
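+            # Characters that have a dedicated LM encoding (e.g. space, line break)
+            # are appended as their encoded token; every other character is kept as-is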
+ self.language_tokens.append( + self.mapping.encode[token] + ) if token in self.mapping.encode else self.language_tokens.append(token) + self.language_tokens.append(self.mapping.ctc.encoded) + + # Build LM corpus + train_corpus = [ + values["text"].replace( + self.mapping.linebreak.display, self.mapping.space.display + ) + for values in self.split[TRAIN_NAME].values() + ] + + tokenizer = Tokenizer( + training_corpus=train_corpus, + charset=self.language_tokens, + unknown_token=self.unknown_token, + outdir=self.output / LANGUAGE_DIR, + mapping=self.mapping, + tokens=self.tokens, + subword_vocab_size=self.subword_vocab_size, + ) + + if not tokenizer.sentencepiece_model: + return + + for level, tokenize in ( + ("characters", tokenizer.char_tokenize), + ("words", tokenizer.word_tokenize), + ("subwords", tokenizer.subword_tokenize), + ): + self.language_corpus[level] = list(map(tokenize, train_corpus)) + + # Build LM lexicon + self.language_lexicon["characters"] = [ + f"{token} {token}" for token in self.language_tokens + ] + for level in ["words", "subwords"]: + self.language_lexicon[level] = [ + f"{token} {tokenizer.char_tokenize(token)}" + for token in get_vocabulary(self.language_corpus[level]) + ] + def export(self) -> None: """ Writes a `labels.json` file containing a mapping of the images that have been correctly uploaded (identified by its path) @@ -243,6 +342,20 @@ class ImageDownloader: ) ) + for level in ["characters", "words", "subwords"]: + (self.output / LANGUAGE_DIR / f"corpus_{level}.txt").write_text( + "\n".join(self.language_corpus[level]) + ) + (self.output / LANGUAGE_DIR / f"lexicon_{level}.txt").write_text( + "\n".join(self.language_lexicon[level]) + ) + (self.output / LANGUAGE_DIR / "tokens.txt").write_text( + "\n".join(self.language_tokens) + ) + (self.output / "charset.pkl").write_bytes( + pickle.dumps(sorted(list(self.charset))) + ) + def run(self) -> None: """ Download the missing images from a `split.json` file and build a `labels.json` file containing @@ -251,6 +364,7 @@ class ImageDownloader: """ tasks: List[Dict[str, str]] = self.build_tasks() self.download_images(tasks) + self.format_lm_files() self.export() @@ -259,6 +373,9 @@ def run( max_width: int | None, max_height: int | None, image_format: str, + unknown_token: str, + subword_vocab_size: int, + tokens: Path | None, ): """ Download the missing images from a `split.json` file and build a `labels.json` file containing @@ -269,10 +386,18 @@ def run( :param max_width: Images larger than this width will be resized to this width :param max_height: Images larger than this height will be resized to this height :param image_format: Images will be saved under this format + :param unknown_token: The token used to replace unknown characters. + :param subword_vocab_size: The size of the subword vocabulary. + :param tokens: Mapping between starting tokens and end tokens to extract text with their entities.. 
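+    In addition to `labels.json`, a `charset.pkl` file and the language-model
+    resources (corpora, lexicons, tokens) are written under `output`.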
""" + (output / LANGUAGE_DIR).mkdir(parents=True, exist_ok=True) + ImageDownloader( output=output, max_width=max_width, max_height=max_height, image_extension=image_format, + unknown_token=unknown_token, + subword_vocab_size=subword_vocab_size, + tokens=tokens, ).run() diff --git a/dan/datasets/download/utils.py b/dan/datasets/download/utils.py index 8a7d4609..b80c442b 100644 --- a/dan/datasets/download/utils.py +++ b/dan/datasets/download/utils.py @@ -2,11 +2,18 @@ # This code is licensed under CeCILL-C # -*- coding: utf-8 -*- +import itertools import logging +import operator +from dataclasses import dataclass, field from io import BytesIO +from pathlib import Path +from tempfile import NamedTemporaryFile from typing import List import requests +import sentencepiece as spm +from nltk import wordpunct_tokenize from PIL import Image, ImageOps from tenacity import ( retry, @@ -15,6 +22,8 @@ from tenacity import ( wait_exponential, ) +from dan.utils import EntityType, LMTokenMapping + logger = logging.getLogger(__name__) # See http://docs.python-requests.org/en/master/user/advanced/#timeouts @@ -80,3 +89,130 @@ def get_bbox(polygon: List[List[int]]) -> str: x, y = min(all_x), min(all_y) width, height = max(all_x) - x, max(all_y) - y return ",".join(list(map(str, [int(x), int(y), int(width), int(height)]))) + + +def get_vocabulary(tokenized_text: List[str]) -> set[str]: + """ + Compute set of vocabulary from tokenzied text. + :param tokenized_text: List of tokenized text. + """ + return sorted(set([token for doc in tokenized_text for token in doc.split()])) + + +@dataclass +class Tokenizer: + """ + A multi-level tokenizer (char, subword, word), where the subword tokenizer is trained using sentencepiece. + :param training_corpus: List of training text. + :param outdir: Path to save the subword tokenizer. + :param mapping: Mapping between displayed and encoded versions of special characters. + :param tokens: Start and end tokens used to represent named entities. + :param subword_vocab_size: Size of the vocabulary size to use to train the subword tokenizer. + """ + + training_corpus: List[str] + charset: List[str] + unknown_token: str + outdir: Path + mapping: LMTokenMapping + tokens: EntityType | None = None + subword_vocab_size: int = 1000 + sentencepiece_model: spm.SentencePieceProcessor = field(init=False) + + @property + def prefix(self) -> Path: + return self.outdir / "subword_tokenizer" + + @property + def ner_tokens(self) -> List[str]: + if self.tokens is None: + return [] + return list( + itertools.chain( + map(operator.attrgetter("start"), self.tokens.values()), + filter( + operator.truth, + map(operator.attrgetter("end"), self.tokens.values()), + ), + ) + ) + + @property + def mapping_tokens(self) -> List[str]: + return [token.encoded for token in self.mapping] + + @property + def special_tokens(self) -> List[str]: + return list(set(itertools.chain(self.mapping_tokens, self.ner_tokens))) + + def __post_init__(self) -> None: + """ + Train a sentencepiece model on the training corpus. 
+ """ + # Write the corpus in a text file + logger.info("Training a sentencepiece model for subword tokenization") + with NamedTemporaryFile(dir=self.outdir, suffix=".txt", mode="w") as tmp_file: + tmp_file.write("\n".join(self.training_corpus)) + tmp_file.flush() + + try: + spm.SentencePieceTrainer.train( + input=tmp_file.name, + vocab_size=self.subword_vocab_size, + model_prefix=self.prefix, + user_defined_symbols=self.special_tokens, + minloglevel=1, + ) + except Exception as e: + logger.warning( + f"Failed to train a sentencepiece model for subword tokenization: {e} " + "Try again by editing the `--subword-vocab-size` parameter." + ) + self.sentencepiece_model = None + return + + # Load the model + self.sentencepiece_model = spm.SentencePieceProcessor( + model_file=str(self.prefix.with_suffix(".model")) + ) + + def subword_tokenize(self, text: str) -> str: + """ + Tokenize into subwords. Sampling is disabled to ensure reproducibility. + """ + tokens = self.sentencepiece_model.encode(text, out_type=str) + return " ".join(map("".join, map(self.encode, tokens))) + + def word_tokenize(self, text: str) -> str: + """ + Tokenize text into a string of space-separated words. Spaces (⎵) and NER tokens are considered as words. + :param text: Text to be tokenized. + """ + words = list(map("".join, map(self.encode, wordpunct_tokenize(text)))) + return " ".join( + [ + f"{word} {self.mapping.space.encoded}" + if (i != len(words) - 1 and word not in self.ner_tokens) + else word + for i, word in enumerate(words) + ] + ) + + def char_tokenize(self, text: str) -> str: + """ + Tokenize text into a string of space-separated characters. + :param text: Text to be tokenized. + """ + return " ".join( + [ + char if char in self.charset else self.unknown_token + for char in self.encode(text) + ] + ) + + def encode(self, text: List[str]) -> List[str]: + """ + Encode special tokens. + :param text: Text to be encoded. + """ + return map(self.mapping.encode_token, text) diff --git a/dan/datasets/extract/__init__.py b/dan/datasets/extract/__init__.py index 87522fae..9b90c626 100644 --- a/dan/datasets/extract/__init__.py +++ b/dan/datasets/extract/__init__.py @@ -82,12 +82,7 @@ def add_extract_parser(subcommands) -> None: Do not give any arguments to keep the whole text. 
""", ) - parser.add_argument( - "--unknown-token", - type=str, - help="Token to use to replace character in the validation/test sets that is not included in the training set.", - default="â‡", - ) + parser.add_argument( "--tokens", type=pathlib.Path, @@ -124,13 +119,6 @@ def add_extract_parser(subcommands) -> None: default=[], ) - parser.add_argument( - "--subword-vocab-size", - type=int, - default=1000, - help="Size of the vocabulary to train the sentencepiece subword tokenizer needed for language model.", - ) - # Formatting arguments parser.add_argument( "--keep-spaces", diff --git a/dan/datasets/extract/arkindex.py b/dan/datasets/extract/arkindex.py index ba40bf90..e06310ab 100644 --- a/dan/datasets/extract/arkindex.py +++ b/dan/datasets/extract/arkindex.py @@ -5,7 +5,6 @@ import json import logging -import pickle import random from collections import defaultdict from pathlib import Path @@ -24,19 +23,14 @@ from dan.datasets.extract.db import ( from dan.datasets.extract.exceptions import ( NoTranscriptionError, ProcessingError, - UnknownTokenInText, ) from dan.datasets.extract.utils import ( - Tokenizer, entities_to_xml, get_translation_map, - get_vocabulary, normalize_linebreaks, normalize_spaces, ) -from dan.utils import LMTokenMapping, parse_tokens - -LANGUAGE_DIR = "language_model" # Subpath to the language model directory. +from dan.utils import parse_tokens TRAIN_NAME = "train" VAL_NAME = "val" @@ -57,7 +51,6 @@ class ArkindexExtractor: dataset_ids: List[UUID] | None = None, element_type: List[str] = [], entity_separators: List[str] = ["\n", " "], - unknown_token: str = "â‡", tokens: Path | None = None, transcription_worker_versions: List[str | bool] = [], entity_worker_versions: List[str | bool] = [], @@ -65,44 +58,25 @@ class ArkindexExtractor: entity_worker_runs: List[str | bool] = [], keep_spaces: bool = False, allow_empty: bool = False, - subword_vocab_size: int = 1000, ) -> None: self.dataset_ids = dataset_ids self.element_type = element_type self.output = output self.entity_separators = entity_separators - self.unknown_token = unknown_token self.tokens = parse_tokens(tokens) if tokens else {} self.transcription_worker_versions = transcription_worker_versions self.entity_worker_versions = entity_worker_versions self.transcription_worker_runs = transcription_worker_runs self.entity_worker_runs = entity_worker_runs self.allow_empty = allow_empty - self.mapping = LMTokenMapping() self.keep_spaces = keep_spaces - self.subword_vocab_size = subword_vocab_size - # Loading file from precedent extraction data_path = self.output / "split.json" - charset_path = self.output / "charset.pkl" - - is_data_file = data_path.exists() - is_charset_file = charset_path.exists() - - self.data: Dict = defaultdict(dict) - self.charset = set() - - if is_data_file and is_charset_file: - self.data.update(json.loads(data_path.read_bytes())) - self.charset.update(sorted(pickle.loads(charset_path.read_bytes()))) - elif is_data_file ^ is_charset_file: - raise FileNotFoundError( - f"The file '{data_path.name}' or `{charset_path.name}` is missing at location {self.output.as_posix()}" - ) - - self.language_corpus = defaultdict(list) - self.language_tokens = [] - self.language_lexicon = defaultdict(list) + # New keys can appear between several extractions + # We must explicitly define that this dict expects a dict as its value + self.data = defaultdict(dict) + if data_path.exists(): + self.data.update(json.loads(data_path.read_text())) # NER extraction self.translation_map: Dict[str, str] | None = 
get_translation_map(self.tokens) @@ -152,20 +126,11 @@ class ArkindexExtractor: ) ) - def format_text(self, text: str, charset: set | None = None): + def format_text(self, text: str): if not self.keep_spaces: text = normalize_spaces(text) text = normalize_linebreaks(text) - # Replace unknown characters by the unknown token - if charset is not None: - unknown_charset = set(text) - charset - text = text.translate( - { - ord(unknown_char): self.unknown_token - for unknown_char in unknown_charset - } - ) return text.strip() def process_element(self, dataset_parent: DatasetElement, element: Element): @@ -174,15 +139,7 @@ class ArkindexExtractor: The output path is directly related to the split of the element. """ text = self.extract_transcription(element) - - if self.unknown_token in text: - raise UnknownTokenInText(element_id=element.id) - - text = self.format_text( - text, - # Do not replace unknown characters in train split - charset=self.charset if dataset_parent.set_name != TRAIN_NAME else None, - ) + text = self.format_text(text) self.data[dataset_parent.set_name][element.id] = { "dataset_id": dataset_parent.dataset_id, @@ -193,8 +150,6 @@ class ArkindexExtractor: }, } - self.charset = self.charset.union(set(text)) - def process_parent(self, pbar, dataset_parent: DatasetElement): """ Extract data from a parent element. @@ -223,62 +178,6 @@ class ArkindexExtractor: except ProcessingError as e: logger.warning(f"Skipping {element.id}: {str(e)}") - def format_lm_files(self) -> None: - """ - Convert charset to a LM-compatible charset. Ensure that special LM tokens do not appear in the charset. - """ - logger.info("Preparing language resources") - # Add unknown token to charset - self.charset.add(self.unknown_token) - - # Build LM tokens - for token in sorted(list(self.charset)): - assert ( - token not in self.mapping.encode.values() - ), f"Special token {token} is reserved for language modeling." 
- self.language_tokens.append( - self.mapping.encode[token] - ) if token in self.mapping.encode else self.language_tokens.append(token) - self.language_tokens.append(self.mapping.ctc.encoded) - - # Build LM corpus - train_corpus = [ - values["text"].replace( - self.mapping.linebreak.display, self.mapping.space.display - ) - for values in self.data[TRAIN_NAME].values() - ] - - tokenizer = Tokenizer( - training_corpus=train_corpus, - charset=self.language_tokens, - unknown_token=self.unknown_token, - outdir=self.output / "language_model", - mapping=self.mapping, - tokens=self.tokens, - subword_vocab_size=self.subword_vocab_size, - ) - - if not tokenizer.sentencepiece_model: - return - - for level, tokenize in ( - ("characters", tokenizer.char_tokenize), - ("words", tokenizer.word_tokenize), - ("subwords", tokenizer.subword_tokenize), - ): - self.language_corpus[level] = list(map(tokenize, train_corpus)) - - # Build LM lexicon - self.language_lexicon["characters"] = [ - f"{token} {token}" for token in self.language_tokens - ] - for level in ["words", "subwords"]: - self.language_lexicon[level] = [ - f"{token} {tokenizer.char_tokenize(token)}" - for token in get_vocabulary(self.language_corpus[level]) - ] - def export(self): (self.output / "split.json").write_text( json.dumps( @@ -287,19 +186,6 @@ class ArkindexExtractor: indent=4, ) ) - for level in ["characters", "words", "subwords"]: - (self.output / "language_model" / f"corpus_{level}.txt").write_text( - "\n".join(self.language_corpus[level]) - ) - (self.output / "language_model" / f"lexicon_{level}.txt").write_text( - "\n".join(self.language_lexicon[level]) - ) - (self.output / "language_model" / "tokens.txt").write_text( - "\n".join(self.language_tokens) - ) - (self.output / "charset.pkl").write_bytes( - pickle.dumps(sorted(list(self.charset))) - ) def run(self): # Retrieve the Dataset and its splits from the cache @@ -337,7 +223,6 @@ class ArkindexExtractor: "No data was extracted using the provided export database and parameters." ) - self.format_lm_files() self.export() @@ -347,7 +232,6 @@ def run( element_type: List[str], output: Path, entity_separators: List[str], - unknown_token: str, tokens: Path, transcription_worker_versions: List[str | bool], entity_worker_versions: List[str | bool], @@ -355,20 +239,18 @@ def run( entity_worker_runs: List[str | bool], keep_spaces: bool, allow_empty: bool, - subword_vocab_size: int, ): assert database.exists(), f"No file found @ {database}" open_database(path=database) # Create directories - Path(output, LANGUAGE_DIR).mkdir(parents=True, exist_ok=True) + output.mkdir(parents=True, exist_ok=True) ArkindexExtractor( dataset_ids=dataset_ids, element_type=element_type, output=output, entity_separators=entity_separators, - unknown_token=unknown_token, tokens=tokens, transcription_worker_versions=transcription_worker_versions, entity_worker_versions=entity_worker_versions, @@ -376,5 +258,4 @@ def run( entity_worker_runs=entity_worker_runs, keep_spaces=keep_spaces, allow_empty=allow_empty, - subword_vocab_size=subword_vocab_size, ).run() diff --git a/dan/datasets/extract/exceptions.py b/dan/datasets/extract/exceptions.py index 486d9c93..3fd6b876 100644 --- a/dan/datasets/extract/exceptions.py +++ b/dan/datasets/extract/exceptions.py @@ -30,12 +30,3 @@ class NoTranscriptionError(ElementProcessingError): def __str__(self) -> str: return f"No transcriptions found on element ({self.element_id}) with this config. Skipping." 
- - -class UnknownTokenInText(ElementProcessingError): - """ - Raised when the unknown token is found in a transcription text - """ - - def __str__(self) -> str: - return f"Unknown token found in the transcription text of element ({self.element_id})" diff --git a/dan/datasets/extract/utils.py b/dan/datasets/extract/utils.py index b000b479..9cb3b3ce 100644 --- a/dan/datasets/extract/utils.py +++ b/dan/datasets/extract/utils.py @@ -2,21 +2,15 @@ # This code is licensed under CeCILL-C # -*- coding: utf-8 -*- -import itertools import logging -import operator import re from dataclasses import dataclass, field -from pathlib import Path -from tempfile import NamedTemporaryFile -from typing import Dict, Iterator, List +from typing import Dict, List -import sentencepiece as spm from lxml.etree import Element, SubElement, tostring -from nltk import wordpunct_tokenize from arkindex_export import TranscriptionEntity -from dan.utils import EntityType, LMTokenMapping +from dan.utils import EntityType logger = logging.getLogger(__name__) @@ -54,130 +48,6 @@ def normalize_spaces(text: str) -> str: return TRIM_SPACE_REGEX.sub(" ", text.strip()) -def get_vocabulary(tokenized_text: List[str]) -> set[str]: - """ - Compute set of vocabulary from tokenzied text. - :param tokenized_text: List of tokenized text. - """ - return sorted(set([token for doc in tokenized_text for token in doc.split()])) - - -@dataclass -class Tokenizer: - """ - A multi-level tokenizer (char, subword, word), where the subword tokenizer is trained using sentencepiece. - :param training_corpus: List of training text. - :param outdir: Path to save the subword tokenizer. - :param mapping: Mapping between displayed and encoded versions of special characters. - :param tokens: Start and end tokens used to represent named entities. - :param subword_vocab_size: Size of the vocabulary size to use to train the subword tokenizer. - """ - - training_corpus: List[str] - charset: List[str] - unknown_token: str - outdir: Path - mapping: LMTokenMapping - tokens: EntityType | None = None - subword_vocab_size: int = 1000 - sentencepiece_model: spm.SentencePieceProcessor = field(init=False) - - @property - def prefix(self): - return self.outdir / "subword_tokenizer" - - @property - def ner_tokens(self) -> List[str] | Iterator[str]: - if self.tokens is None: - return [] - return itertools.chain( - map(operator.attrgetter("start"), self.tokens.values()), - filter( - operator.truth, map(operator.attrgetter("end"), self.tokens.values()) - ), - ) - - @property - def mapping_tokens(self) -> List[str]: - return [token.encoded for token in self.mapping] - - @property - def special_tokens(self) -> List[str]: - return list(set(itertools.chain(self.mapping_tokens, self.ner_tokens))) - - def __post_init__(self) -> None: - """ - Train a sentencepiece model on the training corpus. - """ - # Write the corpus in a text file - logger.info("Training a sentencepiece model for subword tokenization") - with NamedTemporaryFile(dir=self.outdir, suffix=".txt", mode="w") as tmp: - tmp.write("\n".join(self.training_corpus)) - tmp.flush() - - try: - spm.SentencePieceTrainer.train( - input=tmp.name, - vocab_size=self.subword_vocab_size, - model_prefix=self.prefix, - user_defined_symbols=self.special_tokens, - minloglevel=1, - ) - except Exception as e: - logger.warning( - f"Failed to train a sentencepiece model for subword tokenization: {e} " - "Try again by editing the `--subword-vocab-size` parameter." 
- ) - self.sentencepiece_model = None - return - - # Load the model - self.sentencepiece_model = spm.SentencePieceProcessor( - model_file=str(self.prefix.with_suffix(".model")) - ) - - def subword_tokenize(self, text: str) -> str: - """ - Tokenize into subwords. Sampling is disabled to ensure reproducibility. - """ - tokens = self.sentencepiece_model.encode(text, out_type=str) - return " ".join(map("".join, map(self.encode, tokens))) - - def word_tokenize(self, text: str) -> str: - """ - Tokenize text into a string of space-separated words. Spaces (⎵) and NER tokens are considered as words. - :param text: Text to be tokenized. - """ - words = list(map("".join, map(self.encode, wordpunct_tokenize(text)))) - return " ".join( - [ - word + f" {self.mapping.space.encoded}" - if (i != len(words) - 1 and word not in self.ner_tokens) - else word - for i, word in enumerate(words) - ] - ) - - def char_tokenize(self, text: str) -> str: - """ - Tokenize text into a string of space-separated characters. - :param text: Text to be tokenized. - """ - return " ".join( - [ - char if char in self.charset else self.unknown_token - for char in self.encode(text) - ] - ) - - def encode(self, text: List[str]) -> List[str]: - """ - Encode special tokens. - :param text: Text to be encoded. - """ - return map(self.mapping.encode_token, text) - - def slugify(text: str): """ Replace invalid characters in text to underscores to use it as XML tag. diff --git a/docs/usage/datasets/download.md b/docs/usage/datasets/download.md index 221f6108..dcb7cb32 100644 --- a/docs/usage/datasets/download.md +++ b/docs/usage/datasets/download.md @@ -4,17 +4,22 @@ Use the `teklia-dan dataset download` command to download images of a dataset from a split extracted by DAN. This will: +- Store the set of characters encountered in the dataset (in the `charset.pkl` file), +- Generate the resources needed to build a n-gram language model at character, subword or word-level with [kenlm](https://github.com/kpu/kenlm) (in the `language_model/` folder). - Generate the images of each element (in the `images/` folder), - Create the mapping of the images that have been correctly uploaded (identified by its path) to the ground-truth transcription (with NER tokens if needed) (in the `labels.json` file). If an image download fails for whatever reason, it won't appear in the transcriptions file. The reason will be printed to stdout at the end of the process. Before trying to download the image, it checks that it wasn't downloaded previously. It is thus safe to run this command twice if a few images failed. -| Parameter | Description | Type | Default | -| ---------------- | -------------------------------------------------------------------------------- | -------------- | ------- | -| `--output` | Path where the `split.json` file is stored and where the data will be generated. | `pathlib.Path` | | -| `--max-width` | Images larger than this width will be resized to this width. | `int` | | -| `--max-height` | Images larger than this height will be resized to this height. | `int` | | -| `--image-format` | Images will be saved under this format. | `str` | `.jpg` | +| Parameter | Description | Type | Default | +| ---------------------- | ------------------------------------------------------------------------------------------------------------------- | -------------- | ------- | +| `--output` | Path where the `split.json` file is stored and where the data will be generated. 
| `pathlib.Path` | | +| `--max-width` | Images larger than this width will be resized to this width. | `int` | | +| `--max-height` | Images larger than this height will be resized to this height. | `int` | | +| `--image-format` | Images will be saved under this format. | `str` | `.jpg` | +| `--unknown-token` | Token to use to replace character in the validation/test sets that is not included in the training set. | `str` | `â‡` | +| `--tokens` | Mapping between starting tokens and end tokens to extract text with their entities. | `pathlib.Path` | | +| `--subword-vocab-size` | Size of the vocabulary used to train the sentencepiece subword tokenizer used to train the optional language model. | `int` | `1000` | The `--output` directory should have a `split.json` JSON-formatted file with a specific format. A mapping of the elements (identified by its ID) to the image information and the ground-truth transcription (with NER tokens if needed). This file can be generated by the `teklia-dan dataset extract` command. More details in the [dedicated page](./extract.md). @@ -41,6 +46,32 @@ The `--output` directory should have a `split.json` JSON-formatted file with a s } ``` +The `--tokens` argument expects a YAML-formatted file with a specific format. A list of entries with each entry describing a NER entity. The label of the entity is the key to a dict mapping the starting and ending tokens respectively. This file can be generated by the `teklia-dan dataset tokens` command. More details in the [dedicated page](./tokens.md). + +```yaml +INTITULE: # Type of the entity on Arkindex + start: ⓘ # Starting token for this entity + end: â’¾ # Optional ending token for this entity +DATE: + start: â““ + end: â’¹ +COTE_SERIE: + start: â“¢ + end: Ⓢ +ANALYSE_COMPL.: + start: â“’ + end: â’¸ +PRECISIONS_SUR_COTE: + start: â“Ÿ + end: â“… +COTE_ARTICLE: + start: â“ + end: â’¶ +CLASSEMENT: + start: â“› + end: â“ +``` + ## Examples ### Download full images diff --git a/docs/usage/datasets/extract.md b/docs/usage/datasets/extract.md index cf0f0124..6c3b8618 100644 --- a/docs/usage/datasets/extract.md +++ b/docs/usage/datasets/extract.md @@ -5,8 +5,6 @@ Use the `teklia-dan dataset extract` command to extract a dataset from an Arkindex export database (SQLite format). This will: - Create a mapping of the elements (identified by its ID) to the image information and the ground-truth transcription (with NER tokens if needed) (in the `split.json` file), -- Store the set of characters encountered in the dataset (in the `charset.pkl` file), -- Generate the resources needed to build a n-gram language model at character, subword or word-level with [kenlm](https://github.com/kpu/kenlm) (in the `language_model/` folder). | Parameter | Description | Type | Default | | --------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------------- | ------- | @@ -15,7 +13,6 @@ Use the `teklia-dan dataset extract` command to extract a dataset from an Arkind | `--element-type` | Type of the elements to extract. You may specify multiple types. | `str` | | | `--output` | Folder where the data will be generated. | `pathlib.Path` | | | `--entity-separators` | Removes all text that does not appear in an entity or in the list of given ordered characters. 
If several separators follow each other, keep only the first to appear in the list. Do not give any arguments to keep the whole text (see [dedicated section](#examples)). | `str` | | -| `--unknown-token` | Token to use to replace character in the validation/test sets that is not included in the training set. | `str` | `â‡` | | `--tokens` | Mapping between starting tokens and end tokens to extract text with their entities. | `pathlib.Path` | | | `--transcription-worker-versions` | Filter transcriptions by worker_version. Use `manual` for manual filtering. | `str` or `uuid` | | | `--entity-worker-versions` | Filter transcriptions entities by worker_version. Use `manual` for manual filtering | `str` or `uuid` | | @@ -23,7 +20,6 @@ Use the `teklia-dan dataset extract` command to extract a dataset from an Arkind | `--entity-worker-runs` | Filter transcriptions entities by worker_runs. Use `manual` for manual filtering | `str` or `uuid` | | | `--keep-spaces` | Transcriptions are trimmed by default. Use this flag to disable this behaviour. | `bool` | `False` | | `--allow-empty` | Elements with no transcriptions are skipped by default. This flag disables this behaviour. | `bool` | `False` | -| `--subword-vocab-size` | Size of the vocabulary used to train the sentencepiece subword tokenizer used to train the optional language model. | `int` | `1000` | The `--tokens` argument expects a YAML-formatted file with a specific format. A list of entries with each entry describing a NER entity. The label of the entity is the key to a dict mapping the starting and ending tokens respectively. This file can be generated by the `teklia-dan dataset tokens` command. More details in the [dedicated page](./tokens.md). diff --git a/tests/__init__.py b/tests/__init__.py index cc43796c..7417ec29 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -2,7 +2,58 @@ # This code is licensed under CeCILL-C # -*- coding: utf-8 -*- - +import re from pathlib import Path FIXTURES = Path(__file__).resolve().parent / "data" + +TWO_SPACES_REGEX = re.compile(r" {2}") + + +def change_split_content( + load_entities: bool, + transcription_entities_worker_version: str | bool, + keep_spaces: bool, + split_content: dict, + tokens: list, + expected_labels: dict = [], +): + # Transcriptions with worker version are in lowercase + if transcription_entities_worker_version: + for split in split_content: + for element_id in split_content[split]: + split_content[split][element_id]["text"] = split_content[split][ + element_id + ]["text"].lower() + for split in expected_labels: + for image in expected_labels[split]: + expected_labels[split][image] = expected_labels[split][image].lower() + + # If we do not load entities, remove tokens + if not load_entities: + token_translations = {ord(token): None for token in tokens} + for split in split_content: + for element_id in split_content[split]: + split_content[split][element_id]["text"] = split_content[split][ + element_id + ]["text"].translate(token_translations) + for split in expected_labels: + for image in expected_labels[split]: + expected_labels[split][image] = expected_labels[split][image].translate( + token_translations + ) + + # Replace double spaces with regular space + if not keep_spaces: + for split in split_content: + for element_id in split_content[split]: + split_content[split][element_id]["text"] = TWO_SPACES_REGEX.sub( + " ", split_content[split][element_id]["text"] + ) + for split in expected_labels: + for image in expected_labels[split]: + expected_labels[split][image] = 
TWO_SPACES_REGEX.sub( + " ", expected_labels[split][image] + ) + + return split_content, expected_labels diff --git a/tests/data/extraction/split.json b/tests/data/extraction/split.json index 143073ab..a101df54 100644 --- a/tests/data/extraction/split.json +++ b/tests/data/extraction/split.json @@ -27,7 +27,7 @@ ] ] }, - "text": "â“¢Leunaut â“•Clauâ‡e â“‘â‡â‡" + "text": "â“¢Leunaut â“•Claude â“‘49" }, "test-page_1-line_2": { "dataset_id": "dataset_id", @@ -56,7 +56,7 @@ ] ] }, - "text": "â“¢â‡auracâ‡o â“•Clauâ‡ine â“‘â‡â‡" + "text": "â“¢Bauracho â“•Claudine â“‘39" }, "test-page_1-line_3": { "dataset_id": "dataset_id", @@ -85,7 +85,7 @@ ] ] }, - "text": "â“¢Laurent â“•Jacâ‡use â“‘21" + "text": "â“¢Laurent â“•Jacquse â“‘21" }, "test-page_2-line_1": { "dataset_id": "dataset_id", @@ -114,7 +114,7 @@ ] ] }, - "text": "â“¢â‡alette â“•Elisaâ‡et⇠ⓑ7â‡" + "text": "â“¢Valette â“•Elisabeth â“‘76" }, "test-page_2-line_2": { "dataset_id": "dataset_id", @@ -143,7 +143,7 @@ ] ] }, - "text": "â“¢Tanâ‡ol â“•Jean â“‘7â‡" + "text": "â“¢Tanbol â“•Jean â“‘76" }, "test-page_2-line_3": { "dataset_id": "dataset_id", @@ -172,7 +172,7 @@ ] ] }, - "text": "â“¢â‡auret â“•Jean â“‘â‡â‡" + "text": "â“¢Vauret â“•Jean â“‘64" } }, "train": { @@ -408,7 +408,7 @@ ] ] }, - "text": "â“¢Cirau⇠ⓕAntoine â“‘â‡â‡" + "text": "â“¢Ciraud â“•Antoine â“‘34" }, "val-page_1-line_2": { "dataset_id": "dataset_id", @@ -437,7 +437,7 @@ ] ] }, - "text": "â“¢Cirau⇠ⓕPriser â“‘â‡â‡" + "text": "â“¢Ciraud â“•Priser â“‘34" }, "val-page_1-line_3": { "dataset_id": "dataset_id", @@ -466,7 +466,7 @@ ] ] }, - "text": "â“¢Cirau⇠ⓕElisaâ‡et⇠ⓑâ‡â‡" + "text": "â“¢Ciraud â“•Elisabeth â“‘34" } } } diff --git a/tests/test_download.py b/tests/test_download.py index 1c1879fb..ea818ed2 100644 --- a/tests/test_download.py +++ b/tests/test_download.py @@ -2,9 +2,10 @@ # This code is licensed under CeCILL-C # -*- coding: utf-8 -*- - import json import logging +import pickle +import re from operator import attrgetter, methodcaller from pathlib import Path @@ -13,11 +14,15 @@ from PIL import Image, ImageChops from dan.datasets.download.images import IIIF_FULL_SIZE, ImageDownloader from dan.datasets.download.utils import download_image +from dan.utils import parse_tokens from line_image_extractor.image_utils import BoundingBox -from tests import FIXTURES +from tests import FIXTURES, change_split_content EXTRACTION_DATA_PATH = FIXTURES / "extraction" +ENTITY_TOKEN_SPACE = re.compile(r"[â“¢|â“•|â“‘] ") +TWO_SPACES_LM_REGEX = re.compile(r"â– â–") + @pytest.mark.parametrize( "max_width, max_height, width, height, resize", @@ -28,25 +33,202 @@ EXTRACTION_DATA_PATH = FIXTURES / "extraction" (1000, 2000, 2000, 3000, "1000,"), ), ) -def test_get_iiif_size_arg(max_width, max_height, width, height, resize): +def test_get_iiif_size_arg(max_width, max_height, width, height, resize, tmp_path): + split_path = tmp_path / "output" / "split.json" + split_path.parent.mkdir() + split_path.write_text(json.dumps({"train": {}})) + assert ( - ImageDownloader(max_width=max_width, max_height=max_height).get_iiif_size_arg( - width=width, height=height - ) + ImageDownloader( + output=split_path.parent, max_width=max_width, max_height=max_height + ).get_iiif_size_arg(width=width, height=height) == resize ) -def test_download(split_content, monkeypatch, tmp_path): +@pytest.mark.parametrize( + "load_entities,keep_spaces,transcription_entities_worker_version,expected_subword_language_corpus,subword_vocab_size", + ( + ( + True, + True, + "worker_version_id", + """â– â“¢ l a u l ont â– â“• f r an c oi s 
â– â“‘ 8 +â– â“¢ c i re t â– â“• an t oi ne â– â“‘ 2 7 +â– â“¢ c i re t â– â“• m a r ie â– â“‘ 2 8 +â– â“¢ c i re t â– â“• m a r ie â– â“‘ 2 +â– â“¢ e u re s t on â– â“• so l an g e â– â“‘ 1 0 +â– â“¢ t e r ont u s s ie u x â– â“• j e an â– â“‘ 2 +â– â“¢ p re s s on e t â– â“• m a r ie â– â“‘ 1 2""", + 40, + ), + ( + True, + False, + "worker_version_id", + """â– â“¢ l a u l ont â– â“• f r an c oi s â– â“‘ 8 +â– â“¢ c i re t â– â“• an t oi ne â– â“‘ 2 7 +â– â“¢ c i re t â– â“• m a r ie â– â“‘ 2 8 +â– â“¢ c i re t â– â“• m a r ie â– â“‘ 2 +â– â“¢ e u re s t on â– â“• so l an g e â– â“‘ 1 0 +â– â“¢ t e r ont u s s ie u x â– â“• j e an â– â“‘ 2 +â– â“¢ p re s s on e t â– â“• m a r ie â– â“‘ 1 2""", + 40, + ), + ( + False, + True, + "worker_version_id", + """â– la u l ont â– f r an c oi s â– 8 +â– c i re t â– an t oi ne â– 2 7 +â– c i re t â– m a r ie â– 2 8 +â– c i re t â– m a r ie â– 2 +â– e u res t on â– so l an g e â– 1 0 +â– t e r ont u ss ie u x â– j e an â– 2 +â– p res so ne t â– m a r ie â– 1 2""", + 40, + ), + ( + False, + False, + "worker_version_id", + """â– la u l ont â– f r an c oi s â– 8 +â– c i re t â– an t oi ne â– 2 7 +â– c i re t â– m a r ie â– 2 8 +â– c i re t â– m a r ie â– 2 +â– e u res t on â– so l an g e â– 1 0 +â– t e r ont u ss ie u x â– j e an â– 2 +â– p res so ne t â– m a r ie â– 1 2""", + 40, + ), + ( + True, + True, + False, + """â– â“¢ L a u l o n t â– â“• F r a n c o i s â– â“‘ 8 +â– â“¢ C i r e t â– â“• A n t o i n e â– â“‘ 2 7 +â– â“¢ C i r e t â– â“• M a r ie â– â“‘ 2 8 +â– â“¢ C i r e t â– â“• M a r ie â– â“‘ 2 +â– â“¢ E u r e s t o n â– â“• S o l a n g e â– â“‘ 1 0 +â– â“¢ T e r o n t u s s ie u x â– â“• J e a n â– â“‘ 2 +â– â“¢ P r e s s o n e t â– â“• M a r ie â– â“‘ 1 2""", + 40, + ), + ( + True, + True, + False, + """â– â“¢ L a u l ont â– â“• F r an c oi s â– â“‘ 8 +â– â“¢ C i re t â– â“• A n t oi n e â– â“‘ 2 7 +â– â“¢ C i re t â– â“• M a r ie â– â“‘ 2 8 +â– â“¢ C i re t â– â“• M a r ie â– â“‘ 2 +â– â“¢ E u re s t on â– â“• S o l an g e â– â“‘ 1 0 +â– â“¢ T e r ont u s s ie u x â– â“• J e an â– â“‘ 2 +â– â“¢ P re s s on e t â– â“• M a r ie â– â“‘ 1 2""", + 45, + ), + ( + True, + False, + False, + """â– â“¢ L a u l o n t â– â“• F r a n c o i s â– â“‘ 8 +â– â“¢ C i r e t â– â“• A n t o i n e â– â“‘ 2 7 +â– â“¢ C i r e t â– â“• M a r ie â– â“‘ 2 8 +â– â“¢ C i r e t â– â“• M a r ie â– â“‘ 2 +â– â“¢ E u r e s t o n â– â“• S o l a n g e â– â“‘ 1 0 +â– â“¢ T e r o n t u s s ie u x â– â“• J e a n â– â“‘ 2 +â– â“¢ P r e s s o n e t â– â“• M a r ie â– â“‘ 1 2""", + 40, + ), + ( + False, + True, + False, + """â– L a u l ont â– F r an c oi s â– 8 +â– C i re t â– A n t oi n e â– 2 7 +â– C i re t â– M a r ie â– 2 8 +â– C i re t â– M a r ie â– 2 +â– E u re s t on â– S o l an g e â– 1 0 +â– T e r ont u s s ie u x â– J e an â– 2 +â– P re s s on e t â– M a r ie â– 1 2""", + 40, + ), + ( + False, + False, + False, + """â– L a u l ont â– F r an c oi s â– 8 +â– C i re t â– A n t oi n e â– 2 7 +â– C i re t â– M a r ie â– 2 8 +â– C i re t â– M a r ie â– 2 +â– E u re s t on â– S o l an g e â– 1 0 +â– T e r ont u s s ie u x â– J e an â– 2 +â– P re s s on e t â– M a r ie â– 1 2""", + 40, + ), + ), +) +def test_download( + load_entities, + keep_spaces, + transcription_entities_worker_version, + expected_subword_language_corpus, + subword_vocab_size, + split_content, + monkeypatch, + tmp_path, +): + output = tmp_path / "download" + (output / "language_model").mkdir(parents=True, exist_ok=True) + + # Mock tokens + tokens_path = EXTRACTION_DATA_PATH / "tokens.yml" + tokens = [ + 
token + for entity_type in parse_tokens(tokens_path).values() + for token in [entity_type.start, entity_type.end] + if token + ] + + # Mock "split.json" + split_content, expected_labels = change_split_content( + load_entities, + transcription_entities_worker_version, + keep_spaces, + split_content, + tokens, + { + "test": { + "images/test/dataset_id/test-page_1-line_1.jpg": "â“¢Leunaut â“•Clauâ‡e â“‘â‡â‡", + "images/test/dataset_id/test-page_1-line_2.jpg": "â“¢â‡auracâ‡o â“•Clauâ‡ine â“‘â‡â‡", + "images/test/dataset_id/test-page_1-line_3.jpg": "â“¢Laurent â“•Jacâ‡use â“‘21", + "images/test/dataset_id/test-page_2-line_1.jpg": "â“¢â‡alette â“•Elisaâ‡et⇠ⓑ7â‡", + "images/test/dataset_id/test-page_2-line_2.jpg": "â“¢Tanâ‡ol â“•Jean â“‘7â‡", + "images/test/dataset_id/test-page_2-line_3.jpg": "â“¢â‡auret â“•Jean â“‘â‡â‡", + }, + "train": { + "images/train/dataset_id/train-page_1-line_1.jpg": "â“¢Laulont â“•Francois â“‘8", + "images/train/dataset_id/train-page_1-line_2.jpg": "â“¢Ciret â“•Antoine â“‘27", + "images/train/dataset_id/train-page_1-line_3.jpg": "â“¢Ciret â“•Marie â“‘28", + "images/train/dataset_id/train-page_1-line_4.jpg": "â“¢Ciret â“•Marie â“‘2", + "images/train/dataset_id/train-page_2-line_1.jpg": "â“¢Eureston â“•Solange â“‘10", + "images/train/dataset_id/train-page_2-line_2.jpg": "â“¢Terontussieux â“•Jean â“‘2", + "images/train/dataset_id/train-page_2-line_3.jpg": "â“¢Pressonet â“•Marie â“‘12", + }, + "val": { + "images/val/dataset_id/val-page_1-line_1.jpg": "â“¢Cirau⇠ⓕAntoine â“‘â‡â‡", + "images/val/dataset_id/val-page_1-line_2.jpg": "â“¢Cirau⇠ⓕPriser â“‘â‡â‡", + "images/val/dataset_id/val-page_1-line_3.jpg": "â“¢Cirau⇠ⓕElisaâ‡et⇠ⓑâ‡â‡", + }, + }, + ) + (output / "split.json").write_text(json.dumps(split_content)) + # Mock download_image so that it simply opens it with Pillow monkeypatch.setattr( "dan.datasets.download.images.download_image", lambda url: Image.open(url) ) - output = tmp_path / "download" - output.mkdir(parents=True, exist_ok=True) - (output / "split.json").write_text(json.dumps(split_content)) - def mock_build_image_url(polygon, image_url, *args, **kwargs): # During tests, the image URL is its local path return image_url @@ -54,6 +236,8 @@ def test_download(split_content, monkeypatch, tmp_path): extractor = ImageDownloader( output=output, image_extension=".jpg", + tokens=tokens_path if load_entities else None, + subword_vocab_size=subword_vocab_size, ) # Mock build_image_url to simply return the path to the image extractor.build_iiif_url = mock_build_image_url @@ -66,6 +250,7 @@ def test_download(split_content, monkeypatch, tmp_path): VAL_DIR = IMAGE_DIR / "val" / "dataset_id" expected_paths = [ + output / "charset.pkl", # Images of test folder TEST_DIR / "test-page_1-line_1.jpg", TEST_DIR / "test-page_1-line_2.jpg", @@ -86,38 +271,124 @@ def test_download(split_content, monkeypatch, tmp_path): VAL_DIR / "val-page_1-line_2.jpg", VAL_DIR / "val-page_1-line_3.jpg", output / "labels.json", + # Language resources + output / "language_model" / "corpus_characters.txt", + output / "language_model" / "corpus_subwords.txt", + output / "language_model" / "corpus_words.txt", + output / "language_model" / "lexicon_characters.txt", + output / "language_model" / "lexicon_subwords.txt", + output / "language_model" / "lexicon_words.txt", + output / "language_model" / "subword_tokenizer.model", + output / "language_model" / "subword_tokenizer.vocab", + output / "language_model" / "tokens.txt", output / "split.json", ] assert sorted(filter(methodcaller("is_file"), 
output.rglob("*"))) == expected_paths - # Check "labels.json" - expected_labels = { - "test": { - "images/test/dataset_id/test-page_1-line_1.jpg": "â“¢Leunaut â“•Clauâ‡e â“‘â‡â‡", - "images/test/dataset_id/test-page_1-line_2.jpg": "â“¢â‡auracâ‡o â“•Clauâ‡ine â“‘â‡â‡", - "images/test/dataset_id/test-page_1-line_3.jpg": "â“¢Laurent â“•Jacâ‡use â“‘21", - "images/test/dataset_id/test-page_2-line_1.jpg": "â“¢â‡alette â“•Elisaâ‡et⇠ⓑ7â‡", - "images/test/dataset_id/test-page_2-line_2.jpg": "â“¢Tanâ‡ol â“•Jean â“‘7â‡", - "images/test/dataset_id/test-page_2-line_3.jpg": "â“¢â‡auret â“•Jean â“‘â‡â‡", - }, - "train": { - "images/train/dataset_id/train-page_1-line_1.jpg": "â“¢Laulont â“•Francois â“‘8", - "images/train/dataset_id/train-page_1-line_2.jpg": "â“¢Ciret â“•Antoine â“‘27", - "images/train/dataset_id/train-page_1-line_3.jpg": "â“¢Ciret â“•Marie â“‘28", - "images/train/dataset_id/train-page_1-line_4.jpg": "â“¢Ciret â“•Marie â“‘2", - "images/train/dataset_id/train-page_2-line_1.jpg": "â“¢Eureston â“•Solange â“‘10", - "images/train/dataset_id/train-page_2-line_2.jpg": "â“¢Terontussieux â“•Jean â“‘2", - "images/train/dataset_id/train-page_2-line_3.jpg": "â“¢Pressonet â“•Marie â“‘12", - }, - "val": { - "images/val/dataset_id/val-page_1-line_1.jpg": "â“¢Cirau⇠ⓕAntoine â“‘â‡â‡", - "images/val/dataset_id/val-page_1-line_2.jpg": "â“¢Cirau⇠ⓕPriser â“‘â‡â‡", - "images/val/dataset_id/val-page_1-line_3.jpg": "â“¢Cirau⇠ⓕElisaâ‡et⇠ⓑâ‡â‡", - }, - } + # Check "charset.pkl" + expected_charset = {"â‡"} + for values in split_content["train"].values(): + expected_charset.update(set(values["text"])) + if load_entities: + expected_charset.update(tokens) + + assert set(pickle.loads((output / "charset.pkl").read_bytes())) == expected_charset + + # Check "labels.json" assert json.loads((output / "labels.json").read_text()) == expected_labels + # Check "language_corpus.txt" + expected_char_language_corpus = """â“¢ L a u l o n t â– â– â“• F r a n c o i s â– â– â“‘ 8 +â“¢ C i r e t â– â– â“• A n t o i n e â– â– â“‘ 2 7 +â“¢ C i r e t â– â– â“• M a r i e â– â– â“‘ 2 8 +â“¢ C i r e t â– â– â“• M a r i e â– â– â“‘ 2 +â“¢ E u r e s t o n â– â– â“• S o l a n g e â– â– â“‘ 1 0 +â“¢ T e r o n t u s s i e u x â– â– â“• J e a n â– â– â“‘ 2 +â“¢ P r e s s o n e t â– â– â“• M a r i e â– â– â“‘ 1 2""" + + expected_word_language_corpus = """â“¢ Laulont â– â“• Francois â– â“‘ 8 +â“¢ Ciret â– â“• Antoine â– â“‘ 27 +â“¢ Ciret â– â“• Marie â– â“‘ 28 +â“¢ Ciret â– â“• Marie â– â“‘ 2 +â“¢ Eureston â– â“• Solange â– â“‘ 10 +â“¢ Terontussieux â– â“• Jean â– â“‘ 2 +â“¢ Pressonet â– â“• Marie â– â“‘ 12""" + + # Transcriptions with worker version are in lowercase + if transcription_entities_worker_version: + expected_char_language_corpus = expected_char_language_corpus.lower() + expected_word_language_corpus = expected_word_language_corpus.lower() + expected_subword_language_corpus = expected_subword_language_corpus.lower() + + # If we do not load entities, remove tokens + if not load_entities: + expected_char_language_corpus = ENTITY_TOKEN_SPACE.sub( + "", expected_char_language_corpus + ) + expected_word_language_corpus = ENTITY_TOKEN_SPACE.sub( + "", expected_word_language_corpus + ) + expected_subword_language_corpus = ENTITY_TOKEN_SPACE.sub( + "", expected_subword_language_corpus + ) + # Replace double spaces with regular space + if not keep_spaces: + expected_char_language_corpus = TWO_SPACES_LM_REGEX.sub( + "â–", expected_char_language_corpus + ) + expected_word_language_corpus = TWO_SPACES_LM_REGEX.sub( + "â–", 
expected_word_language_corpus + ) + expected_subword_language_corpus = TWO_SPACES_LM_REGEX.sub( + "â–", expected_subword_language_corpus + ) + + assert ( + output / "language_model" / "corpus_characters.txt" + ).read_text() == expected_char_language_corpus + + assert ( + output / "language_model" / "corpus_words.txt" + ).read_text() == expected_word_language_corpus + + assert ( + output / "language_model" / "corpus_subwords.txt" + ).read_text() == expected_subword_language_corpus + + # Check "language_tokens.txt" + expected_language_tokens = [ + "â–" if t.isspace() else t for t in sorted(list(expected_charset)) + ] + expected_language_tokens.append("â—Œ") + assert (output / "language_model" / "tokens.txt").read_text() == "\n".join( + expected_language_tokens + ) + + # Check "language_lexicon.txt" + expected_language_char_lexicon = [f"{t} {t}" for t in expected_language_tokens] + assert ( + output / "language_model" / "lexicon_characters.txt" + ).read_text() == "\n".join(expected_language_char_lexicon) + + word_vocab = set([word for word in expected_word_language_corpus.split()]) + expected_language_word_lexicon = [ + f"{word} {' '.join(word)}" for word in sorted(word_vocab) + ] + assert (output / "language_model" / "lexicon_words.txt").read_text() == "\n".join( + expected_language_word_lexicon + ) + + subword_vocab = set( + [subword for subword in expected_subword_language_corpus.split()] + ) + expected_language_subword_lexicon = [ + f"{subword} {' '.join(subword)}" for subword in sorted(subword_vocab) + ] + assert ( + output / "language_model" / "lexicon_subwords.txt" + ).read_text() == "\n".join(expected_language_subword_lexicon) + # Check cropped images for expected_path in expected_paths: if expected_path.suffix != ".jpg": @@ -129,7 +400,7 @@ def test_download(split_content, monkeypatch, tmp_path): ) -def test_download_image_error(monkeypatch, caplog, capsys): +def test_download_image_error(monkeypatch, caplog, capsys, tmp_path): task = { "split": "train", "polygon": [], @@ -141,7 +412,11 @@ def test_download_image_error(monkeypatch, caplog, capsys): lambda polygon: BoundingBox(0, 0, 0, 0), ) - extractor = ImageDownloader(image_extension=".jpg") + split_path = tmp_path / "output" / "split.json" + split_path.parent.mkdir() + split_path.write_text(json.dumps({"train": {}})) + + extractor = ImageDownloader(output=split_path.parent, image_extension=".jpg") # Add the key in data extractor.data[task["split"]][str(task["destination"])] = "deadbeefdata" diff --git a/tests/test_extract.py b/tests/test_extract.py index 922d1026..60605571 100644 --- a/tests/test_extract.py +++ b/tests/test_extract.py @@ -4,7 +4,6 @@ # -*- coding: utf-8 -*- import json -import pickle import re from operator import methodcaller from typing import NamedTuple @@ -12,7 +11,6 @@ from typing import NamedTuple import pytest from arkindex_export import ( - DatasetElement, Element, Transcription, TranscriptionEntity, @@ -21,7 +19,6 @@ from dan.datasets.extract.arkindex import ArkindexExtractor from dan.datasets.extract.db import get_transcription_entities from dan.datasets.extract.exceptions import ( NoTranscriptionError, - UnknownTokenInText, ) from dan.datasets.extract.utils import ( EntityType, @@ -30,11 +27,10 @@ from dan.datasets.extract.utils import ( normalize_spaces, ) from dan.utils import parse_tokens -from tests import FIXTURES +from tests import FIXTURES, change_split_content EXTRACTION_DATA_PATH = FIXTURES / "extraction" -TWO_SPACES_REGEX = re.compile(r" {2}") ENTITY_TOKEN_SPACE = re.compile(r"[â“¢|â“•|â“‘] ") 
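# Helper patterns: ENTITY_TOKEN_SPACE matches a NER token followed by a space,
# TWO_SPACES_LM_REGEX matches two consecutive LM space markers.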
 TWO_SPACES_LM_REGEX = re.compile(r"▁ ▁")
 
@@ -89,145 +85,10 @@ def test_normalize_linebreaks(text, trimmed):
     assert normalize_linebreaks(text) == trimmed
 
 
-def test_process_element_unknown_token_in_text_error(mock_database, tmp_path):
-    output = tmp_path / "extraction"
-    arkindex_extractor = ArkindexExtractor(output=output)
-
-    # Retrieve a dataset element and update its transcription with an invalid one
-    dataset_element = DatasetElement.select().first()
-    element = dataset_element.element
-    Transcription.update({Transcription.text: "Is this text valid⁇"}).execute()
-
-    with pytest.raises(
-        UnknownTokenInText,
-        match=re.escape(
-            f"Unknown token found in the transcription text of element ({element.id})"
-        ),
-    ):
-        arkindex_extractor.process_element(dataset_element, element)
-
-
+@pytest.mark.parametrize("load_entities", [True, False])
+@pytest.mark.parametrize("keep_spaces", [True, False])
 @pytest.mark.parametrize(
-    "load_entities,keep_spaces,transcription_entities_worker_version,expected_subword_language_corpus,subword_vocab_size",
-    (
-        (
-            True,
-            True,
-            "worker_version_id",
-            """▁ ⓢ l a u l ont ▁ ⓕ f r an c oi s ▁ ⓑ 8
-▁ ⓢ c i re t ▁ ⓕ an t oi ne ▁ ⓑ 2 7
-▁ ⓢ c i re t ▁ ⓕ m a r ie ▁ ⓑ 2 8
-▁ ⓢ c i re t ▁ ⓕ m a r ie ▁ ⓑ 2
-▁ ⓢ e u re s t on ▁ ⓕ so l an g e ▁ ⓑ 1 0
-▁ ⓢ t e r ont u s s ie u x ▁ ⓕ j e an ▁ ⓑ 2
-▁ ⓢ p re s s on e t ▁ ⓕ m a r ie ▁ ⓑ 1 2""",
-            40,
-        ),
-        (
-            True,
-            False,
-            "worker_version_id",
-            """▁ ⓢ l a u l ont ▁ ⓕ f r an c oi s ▁ ⓑ 8
-▁ ⓢ c i re t ▁ ⓕ an t oi ne ▁ ⓑ 2 7
-▁ ⓢ c i re t ▁ ⓕ m a r ie ▁ ⓑ 2 8
-▁ ⓢ c i re t ▁ ⓕ m a r ie ▁ ⓑ 2
-▁ ⓢ e u re s t on ▁ ⓕ so l an g e ▁ ⓑ 1 0
-▁ ⓢ t e r ont u s s ie u x ▁ ⓕ j e an ▁ ⓑ 2
-▁ ⓢ p re s s on e t ▁ ⓕ m a r ie ▁ ⓑ 1 2""",
-            40,
-        ),
-        (
-            False,
-            True,
-            "worker_version_id",
-            """▁ la u l ont ▁ f r an c oi s ▁ 8
-▁ c i re t ▁ an t oi ne ▁ 2 7
-▁ c i re t ▁ m a r ie ▁ 2 8
-▁ c i re t ▁ m a r ie ▁ 2
-▁ e u res t on ▁ so l an g e ▁ 1 0
-▁ t e r ont u ss ie u x ▁ j e an ▁ 2
-▁ p res so ne t ▁ m a r ie ▁ 1 2""",
-            40,
-        ),
-        (
-            False,
-            False,
-            "worker_version_id",
-            """▁ la u l ont ▁ f r an c oi s ▁ 8
-▁ c i re t ▁ an t oi ne ▁ 2 7
-▁ c i re t ▁ m a r ie ▁ 2 8
-▁ c i re t ▁ m a r ie ▁ 2
-▁ e u res t on ▁ so l an g e ▁ 1 0
-▁ t e r ont u ss ie u x ▁ j e an ▁ 2
-▁ p res so ne t ▁ m a r ie ▁ 1 2""",
-            40,
-        ),
-        (
-            True,
-            True,
-            False,
-            """▁ ⓢ L a u l o n t ▁ ⓕ F r a n c o i s ▁ ⓑ 8
-▁ ⓢ C i r e t ▁ ⓕ A n t o i n e ▁ ⓑ 2 7
-▁ ⓢ C i r e t ▁ ⓕ M a r ie ▁ ⓑ 2 8
-▁ ⓢ C i r e t ▁ ⓕ M a r ie ▁ ⓑ 2
-▁ ⓢ E u r e s t o n ▁ ⓕ S o l a n g e ▁ ⓑ 1 0
-▁ ⓢ T e r o n t u s s ie u x ▁ ⓕ J e a n ▁ ⓑ 2
-▁ ⓢ P r e s s o n e t ▁ ⓕ M a r ie ▁ ⓑ 1 2""",
-            40,
-        ),
-        (
-            True,
-            True,
-            False,
-            """▁ ⓢ L a u l ont ▁ ⓕ F r an c oi s ▁ ⓑ 8
-▁ ⓢ C i re t ▁ ⓕ A n t oi n e ▁ ⓑ 2 7
-▁ ⓢ C i re t ▁ ⓕ M a r ie ▁ ⓑ 2 8
-▁ ⓢ C i re t ▁ ⓕ M a r ie ▁ ⓑ 2
-▁ ⓢ E u re s t on ▁ ⓕ S o l an g e ▁ ⓑ 1 0
-▁ ⓢ T e r ont u s s ie u x ▁ ⓕ J e an ▁ ⓑ 2
-▁ ⓢ P re s s on e t ▁ ⓕ M a r ie ▁ ⓑ 1 2""",
-            45,
-        ),
-        (
-            True,
-            False,
-            False,
-            """▁ ⓢ L a u l o n t ▁ ⓕ F r a n c o i s ▁ ⓑ 8
-▁ ⓢ C i r e t ▁ ⓕ A n t o i n e ▁ ⓑ 2 7
-▁ ⓢ C i r e t ▁ ⓕ M a r ie ▁ ⓑ 2 8
-▁ ⓢ C i r e t ▁ ⓕ M a r ie ▁ ⓑ 2
-▁ ⓢ E u r e s t o n ▁ ⓕ S o l a n g e ▁ ⓑ 1 0
-▁ ⓢ T e r o n t u s s ie u x ▁ ⓕ J e a n ▁ ⓑ 2
-▁ ⓢ P r e s s o n e t ▁ ⓕ M a r ie ▁ ⓑ 1 2""",
-            40,
-        ),
-        (
-            False,
-            True,
-            False,
-            """▁ L a u l ont ▁ F r an c oi s ▁ 8
-▁ C i re t ▁ A n t oi n e ▁ 2 7
-▁ C i re t ▁ M a r ie ▁ 2 8
-▁ C i re t ▁ M a r ie ▁ 2
-▁ E u re s t on ▁ S o l an g e ▁ 1 0
-▁ T e r ont u s s ie u x ▁ J e an ▁ 2
-▁ P re s s on e t ▁ M a r ie ▁ 1 2""",
-            40,
-        ),
-        (
-            False,
-            False,
-            False,
-            """▁ L a u l ont ▁ F r an c oi s ▁ 8
-▁ C i re t ▁ A n t oi n e ▁ 2 7
-▁ C i re t ▁ M a r ie ▁ 2 8
-▁ C i re t ▁ M a r ie ▁ 2
-▁ E u re s t on ▁ S o l an g e ▁ 1 0
-▁ T e r ont u s s ie u x ▁ J e an ▁ 2
-▁ P re s s on e t ▁ M a r ie ▁ 1 2""",
-            40,
-        ),
-    ),
+    "transcription_entities_worker_version", ["worker_version_id", False]
 )
 @pytest.mark.parametrize("existing", ((True, False)))
 def test_extract(
@@ -236,14 +97,13 @@
     transcription_entities_worker_version,
     split_content,
     mock_database,
-    expected_subword_language_corpus,
-    subword_vocab_size,
     tmp_path,
     existing,
 ):
     output = tmp_path / "extraction"
     output.mkdir(parents=True, exist_ok=True)
-    (output / "language_model").mkdir(parents=True, exist_ok=True)
+
+    # Mock tokens
     tokens_path = EXTRACTION_DATA_PATH / "tokens.yml"
     tokens = [
         token
@@ -252,17 +112,7 @@
         if token
     ]
 
-    # Add character to fake previous extract file in the folder
-    previous_character = "%"
-
     if existing:
-        charset_path = output / "charset.pkl"
-        data_path = output / "split.json"
-
-        dataset_type = "train"
-
-        data_id = "train-page_1-line_5"
-
         data = {
             "dataset-id": "dataset-id",
             "image": {
@@ -275,24 +125,14 @@
                     [37, 191],
                 ],
             },
-            "text": previous_character,
+            "text": "%",
         }
 
-        charset_path.write_bytes(pickle.dumps([previous_character]))
-        data_path.write_text(
-            json.dumps(
-                {dataset_type: {data_id: data}},
-            )
+        (output / "split.json").write_text(
+            json.dumps({"train": {"train-page_1-line_5": data}})
         )
 
-        split_content[dataset_type][data_id] = data
-
-        keys = list(split_content["train"].keys())
-        keys.sort()
-        split_content["train"] = {i: split_content["train"][i] for i in keys}
-
-        # Add 1 to subword_vocab_size because we have one more subword who is {previous_character}
-        subword_vocab_size += 1
+        split_content["train"]["train-page_1-line_5"] = data
 
     extractor = ArkindexExtractor(
         dataset_ids=["dataset_id"],
@@ -300,173 +140,29 @@
         output=output,
         # Keep the whole text
        entity_separators=None,
-        tokens=tokens_path if load_entities else None,
         transcription_worker_versions=[transcription_entities_worker_version],
         entity_worker_versions=[transcription_entities_worker_version]
         if load_entities
         else [],
+        tokens=tokens_path if load_entities else None,
        keep_spaces=keep_spaces,
-        subword_vocab_size=subword_vocab_size,
     )
     extractor.run()
 
-    expected_paths = [
-        output / "charset.pkl",
-        # Language resources
-        output / "language_model" / "corpus_characters.txt",
-        output / "language_model" / "corpus_subwords.txt",
-        output / "language_model" / "corpus_words.txt",
-        output / "language_model" / "lexicon_characters.txt",
-        output / "language_model" / "lexicon_subwords.txt",
-        output / "language_model" / "lexicon_words.txt",
-        output / "language_model" / "subword_tokenizer.model",
-        output / "language_model" / "subword_tokenizer.vocab",
-        output / "language_model" / "tokens.txt",
-        output / "split.json",
+    assert sorted(filter(methodcaller("is_file"), output.rglob("*"))) == [
+        output / "split.json"
     ]
-    assert sorted(filter(methodcaller("is_file"), output.rglob("*"))) == expected_paths
 
-    # Check "split.json"
-    # Transcriptions with worker version are in lowercase
-    if transcription_entities_worker_version:
-        for split in split_content:
-            for element_id in split_content[split]:
-                split_content[split][element_id]["text"] = split_content[split][
-                    element_id
-                ]["text"].lower()
-
-    # If we do not load entities, remove tokens
-    if not load_entities:
-        token_translations = {ord(token): None for token in tokens}
-        for split in split_content:
-            for element_id in split_content[split]:
-                split_content[split][element_id]["text"] = split_content[split][
-                    element_id
-                ]["text"].translate(token_translations)
-
-    # Replace double spaces with regular space
-    if not keep_spaces:
-        for split in split_content:
-            for element_id in split_content[split]:
-                split_content[split][element_id]["text"] = TWO_SPACES_REGEX.sub(
-                    " ", split_content[split][element_id]["text"]
-                )
-
-    assert json.loads((output / "split.json").read_text()) == split_content
-
-    # Check "charset.pkl"
-    expected_charset = set()
-    for values in split_content["train"].values():
-        expected_charset.update(set(values["text"]))
-
-    if load_entities:
-        expected_charset.update(tokens)
-    expected_charset.add("⁇")
-    assert set(pickle.loads((output / "charset.pkl").read_bytes())) == expected_charset
-
-    # Check "language_corpus.txt"
-    expected_char_language_corpus = """ⓢ L a u l o n t ▁ ▁ ⓕ F r a n c o i s ▁ ▁ ⓑ 8
-ⓢ C i r e t ▁ ▁ ⓕ A n t o i n e ▁ ▁ ⓑ 2 7
-ⓢ C i r e t ▁ ▁ ⓕ M a r i e ▁ ▁ ⓑ 2 8
-ⓢ C i r e t ▁ ▁ ⓕ M a r i e ▁ ▁ ⓑ 2
-ⓢ E u r e s t o n ▁ ▁ ⓕ S o l a n g e ▁ ▁ ⓑ 1 0
-ⓢ T e r o n t u s s i e u x ▁ ▁ ⓕ J e a n ▁ ▁ ⓑ 2
-ⓢ P r e s s o n e t ▁ ▁ ⓕ M a r i e ▁ ▁ ⓑ 1 2"""

-    expected_word_language_corpus = """ⓢ Laulont ▁ ⓕ Francois ▁ ⓑ 8
-ⓢ Ciret ▁ ⓕ Antoine ▁ ⓑ 27
-ⓢ Ciret ▁ ⓕ Marie ▁ ⓑ 28
-ⓢ Ciret ▁ ⓕ Marie ▁ ⓑ 2
-ⓢ Eureston ▁ ⓕ Solange ▁ ⓑ 10
-ⓢ Terontussieux ▁ ⓕ Jean ▁ ⓑ 2
-ⓢ Pressonet ▁ ⓕ Marie ▁ ⓑ 12"""
-
-    if existing:
-        expected_char_language_corpus = (
-            f"{previous_character}\n" + expected_char_language_corpus
-        )
-        expected_word_language_corpus = (
-            f"{previous_character}\n" + expected_word_language_corpus
-        )
-        expected_subword_language_corpus = (
-            f"▁ {previous_character}\n" + expected_subword_language_corpus
-        )
-
-    # Transcriptions with worker version are in lowercase
-    if transcription_entities_worker_version:
-        expected_char_language_corpus = expected_char_language_corpus.lower()
-        expected_word_language_corpus = expected_word_language_corpus.lower()
-        expected_subword_language_corpus = expected_subword_language_corpus.lower()
-
-    # If we do not load entities, remove tokens
-    if not load_entities:
-        token_translations = {f"{token} ": "" for token in tokens}
-        expected_char_language_corpus = ENTITY_TOKEN_SPACE.sub(
-            "", expected_char_language_corpus
-        )
-        expected_word_language_corpus = ENTITY_TOKEN_SPACE.sub(
-            "", expected_word_language_corpus
-        )
-        expected_subword_language_corpus = ENTITY_TOKEN_SPACE.sub(
-            "", expected_subword_language_corpus
-        )
-    # Replace double spaces with regular space
-    if not keep_spaces:
-        expected_char_language_corpus = TWO_SPACES_LM_REGEX.sub(
-            "▁", expected_char_language_corpus
-        )
-        expected_word_language_corpus = TWO_SPACES_LM_REGEX.sub(
-            "▁", expected_word_language_corpus
-        )
-        expected_subword_language_corpus = TWO_SPACES_LM_REGEX.sub(
-            "▁", expected_subword_language_corpus
-        )
-    assert (
-        output / "language_model" / "corpus_characters.txt"
-    ).read_text() == expected_char_language_corpus
-
-    assert (
-        output / "language_model" / "corpus_words.txt"
-    ).read_text() == expected_word_language_corpus
-
-    assert (
-        output / "language_model" / "corpus_subwords.txt"
-    ).read_text() == expected_subword_language_corpus
-
-    # Check "language_tokens.txt"
-    expected_language_tokens = [
-        "▁" if t.isspace() else t for t in sorted(list(expected_charset))
-    ]
-    expected_language_tokens.append("◌")
-    assert (output / "language_model" / "tokens.txt").read_text() == "\n".join(
-        expected_language_tokens
-    )
-
-    # Check "language_lexicon.txt"
-    expected_language_char_lexicon = [f"{t} {t}" for t in expected_language_tokens]
-    assert (
-        output / "language_model" / "lexicon_characters.txt"
-    ).read_text() == "\n".join(expected_language_char_lexicon)
-
-    word_vocab = set([word for word in expected_word_language_corpus.split()])
-    expected_language_word_lexicon = [
-        f"{word} {' '.join(word)}" for word in sorted(word_vocab)
-    ]
-    assert (output / "language_model" / "lexicon_words.txt").read_text() == "\n".join(
-        expected_language_word_lexicon
+    split_content, _ = change_split_content(
+        load_entities,
+        transcription_entities_worker_version,
+        keep_spaces,
+        split_content,
+        tokens,
     )
 
-    subword_vocab = set(
-        [subword for subword in expected_subword_language_corpus.split()]
-    )
-    expected_language_subword_lexicon = [
-        f"{subword} {' '.join(subword)}" for subword in sorted(subword_vocab)
-    ]
-    assert (
-        output / "language_model" / "lexicon_subwords.txt"
-    ).read_text() == "\n".join(expected_language_subword_lexicon)
+    assert json.loads((output / "split.json").read_text()) == split_content
 
 
 @pytest.mark.parametrize("allow_empty", (True, False))
-- 
GitLab
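
Note on the simplified test above: the rewritten test_extract compares split.json against expectations built by a change_split_content helper that lives elsewhere in the test suite and is not shown in this hunk. The sketch below is only an illustration under assumptions, not the helper shipped with this patch: it assumes the helper mirrors the inline logic removed above (lowercasing transcriptions when a worker version is set, stripping entity tokens when entities are not loaded, collapsing double spaces), and that its second return value covers the language-model corpus now exercised by the download tests. Only the helper name and call signature come from the new test; the body and the TWO_SPACES_REGEX constant are assumptions.

# Illustrative sketch only, not the helper added by this patch.
import re
from typing import Dict, List, Tuple

TWO_SPACES_REGEX = re.compile(r"  ")  # assumed equivalent of the test-suite constant


def change_split_content(
    load_entities: bool,
    transcription_entities_worker_version: str | bool,
    keep_spaces: bool,
    split_content: Dict[str, Dict[str, dict]],
    tokens: List[str],
) -> Tuple[Dict[str, Dict[str, dict]], str]:
    def apply(transform) -> None:
        # Apply a text transformation to every element of every split
        for split in split_content:
            for element_id in split_content[split]:
                split_content[split][element_id]["text"] = transform(
                    split_content[split][element_id]["text"]
                )

    # Transcriptions extracted with a worker version are lowercased
    if transcription_entities_worker_version:
        apply(str.lower)

    # Without entity extraction, the entity tokens are stripped from the text
    if not load_entities:
        token_translations = {ord(token): None for token in tokens}
        apply(lambda text: text.translate(token_translations))

    # Double spaces are collapsed unless keep_spaces is requested
    if not keep_spaces:
        apply(lambda text: TWO_SPACES_REGEX.sub(" ", text))

    # The real helper presumably also returns an expected language-model corpus
    # for tests/test_download.py; an empty placeholder keeps two-value unpacking working.
    return split_content, ""

Under these assumptions, the call site reads exactly as in the new test: split_content, _ = change_split_content(load_entities, transcription_entities_worker_version, keep_spaces, split_content, tokens).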