Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • atr/dan
1 result
Show changes
Commits on Source (2)
*gif filter=lfs diff=lfs merge=lfs -text
**/*.pt filter=lfs diff=lfs merge=lfs -text
tests/data/prediction/language_model.arpa filter=lfs diff=lfs merge=lfs -text
......@@ -147,6 +147,13 @@ def add_extract_parser(subcommands) -> None:
help="Images larger than this height will be resized to this height.",
)
parser.add_argument(
"--subword-vocab-size",
type=int,
default=1000,
help="Size of the vocabulary to train the sentencepiece subword tokenizer needed for language model.",
)
# Formatting arguments
parser.add_argument(
"--image-format",
......
......@@ -30,8 +30,10 @@ from dan.datasets.extract.exceptions import (
UnknownTokenInText,
)
from dan.datasets.extract.utils import (
Tokenizer,
download_image,
get_bbox,
get_vocabulary,
insert_token,
normalize_linebreaks,
normalize_spaces,
......@@ -77,6 +79,7 @@ class ArkindexExtractor:
keep_spaces: bool = False,
image_extension: str = "",
allow_empty: bool = False,
subword_vocab_size: int = 1000,
) -> None:
self.folders = folders
self.element_type = element_type
......@@ -92,14 +95,14 @@ class ArkindexExtractor:
self.image_extension = image_extension
self.allow_empty = allow_empty
self.mapping = LMTokenMapping()
self.keep_spaces = keep_spaces
self.subword_vocab_size = subword_vocab_size
self.data: Dict = defaultdict(dict)
self.charset = set()
self.language_corpus = []
self.language_corpus = defaultdict(list)
self.language_tokens = []
self.language_lexicon = []
self.language_lexicon = defaultdict(list)
# Image download tasks to process
self.tasks: List[Dict[str, str]] = []
......@@ -275,12 +278,6 @@ class ArkindexExtractor:
)
return text.strip()
def format_text_language_model(self, text: str):
"""
Format text for the language model. Return the text tokenized at character-level.
"""
return " ".join(map(self.mapping.encode_token, list(text.strip())))
def process_element(
self,
element: Element,
......@@ -319,10 +316,6 @@ class ArkindexExtractor:
self.data[split][str(image_path)] = text
self.charset = self.charset.union(set(text))
# Language model should be built using only text from the training set
if split == "train":
self.language_corpus.append(self.format_text_language_model(text))
def process_parent(
self,
pbar,
......@@ -361,6 +354,11 @@ class ArkindexExtractor:
"""
Convert charset to a LM-compatible charset. Ensure that special LM tokens do not appear in the charset.
"""
logger.info("Preparing language resources")
# Add unknown token to charset
self.charset.add(self.unknown_token)
# Build LM tokens
for token in sorted(list(self.charset)):
assert (
token not in self.mapping.encode.values()
......@@ -368,15 +366,40 @@ class ArkindexExtractor:
self.language_tokens.append(
self.mapping.encode[token]
) if token in self.mapping.encode else self.language_tokens.append(token)
# Add the special blank token
self.language_tokens.append(self.mapping.ctc.encoded)
# Build lexicon
assert all(
[len(token) == 1 for token in self.language_lexicon]
), "Tokens should be single characters."
self.language_lexicon = [f"{token} {token}" for token in self.language_tokens]
# Build LM corpus
train_corpus = [
text.replace(self.mapping.linebreak.display, self.mapping.space.display)
for text in self.data["train"].values()
]
tokenizer = Tokenizer(
training_corpus=train_corpus,
charset=self.language_tokens,
unknown_token=self.unknown_token,
outdir=self.output / "language_model",
mapping=self.mapping,
tokens=self.tokens,
subword_vocab_size=self.subword_vocab_size,
)
for level, tokenize in (
("characters", tokenizer.char_tokenize),
("words", tokenizer.word_tokenize),
("subwords", tokenizer.subword_tokenize),
):
self.language_corpus[level] = list(map(tokenize, train_corpus))
# Build LM lexicon
self.language_lexicon["characters"] = [
f"{token} {token}" for token in self.language_tokens
]
for level in ["words", "subwords"]:
self.language_lexicon[level] = [
f"{token} {tokenizer.char_tokenize(token)}"
for token in get_vocabulary(self.language_corpus[level])
]
def export(self):
(self.output / "labels.json").write_text(
......@@ -386,15 +409,16 @@ class ArkindexExtractor:
indent=4,
)
)
(self.output / "language_model" / "corpus.txt").write_text(
"\n".join(self.language_corpus)
)
for level in ["characters", "words", "subwords"]:
(self.output / "language_model" / f"corpus_{level}.txt").write_text(
"\n".join(self.language_corpus[level])
)
(self.output / "language_model" / f"lexicon_{level}.txt").write_text(
"\n".join(self.language_lexicon[level])
)
(self.output / "language_model" / "tokens.txt").write_text(
"\n".join(self.language_tokens)
)
(self.output / "language_model" / "lexicon.txt").write_text(
"\n".join(self.language_lexicon)
)
(self.output / "charset.pkl").write_bytes(
pickle.dumps(sorted(list(self.charset)))
)
......@@ -477,6 +501,7 @@ def run(
image_format: str,
keep_spaces: bool,
allow_empty: bool,
subword_vocab_size: int,
):
assert database.exists(), f"No file found @ {database}"
open_database(path=database)
......@@ -503,4 +528,5 @@ def run(
keep_spaces=keep_spaces,
image_extension=image_format,
allow_empty=allow_empty,
subword_vocab_size=subword_vocab_size,
).run()
# -*- coding: utf-8 -*-
import itertools
import logging
import operator
import re
from dataclasses import dataclass, field
from io import BytesIO
from typing import List
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import Iterator, List, Optional, Union
import requests
import sentencepiece as spm
from nltk import wordpunct_tokenize
from PIL import Image, ImageOps
from tenacity import (
retry,
......@@ -13,7 +20,7 @@ from tenacity import (
wait_exponential,
)
from dan.utils import EntityType
from dan.utils import EntityType, LMTokenMapping
logger = logging.getLogger(__name__)
......@@ -117,3 +124,117 @@ def get_bbox(polygon: List[List[int]]) -> str:
x, y = min(all_x), min(all_y)
width, height = max(all_x) - x, max(all_y) - y
return ",".join(list(map(str, [int(x), int(y), int(width), int(height)])))
def get_vocabulary(tokenized_text: List[str]) -> set[str]:
"""
Compute set of vocabulary from tokenzied text.
:param tokenized_text: List of tokenized text.
"""
return sorted(set([token for doc in tokenized_text for token in doc.split()]))
@dataclass
class Tokenizer:
"""
A multi-level tokenizer (char, subword, word), where the subword tokenizer is trained using sentencepiece.
:param training_corpus: List of training text.
:param outdir: Path to save the subword tokenizer.
:param mapping: Mapping between displayed and encoded versions of special characters.
:param tokens: Start and end tokens used to represent named entities.
:param subword_vocab_size: Size of the vocabulary size to use to train the subword tokenizer.
"""
training_corpus: List[str]
charset: List[str]
unknown_token: str
outdir: Path
mapping: LMTokenMapping
tokens: Optional[EntityType] = None
subword_vocab_size: int = 1000
sentencepiece_model: spm.SentencePieceProcessor = field(init=False)
@property
def prefix(self):
return self.outdir / "subword_tokenizer"
@property
def ner_tokens(self) -> Union[List[str], Iterator[str]]:
if self.tokens is None:
return []
return itertools.chain(
map(operator.attrgetter("start"), self.tokens.values()),
filter(
operator.truth, map(operator.attrgetter("end"), self.tokens.values())
),
)
@property
def mapping_tokens(self) -> List[str]:
return [token.encoded for token in self.mapping]
@property
def special_tokens(self) -> List[str]:
return list(set(itertools.chain(self.mapping_tokens, self.ner_tokens)))
def __post_init__(self) -> None:
"""
Train a sentencepiece model on the training corpus.
"""
# Write the corpus in a text file
logger.info("Training a sentencepiece model for subword tokenization")
with NamedTemporaryFile(dir=self.outdir, suffix=".txt", mode="w") as tmp:
tmp.write("\n".join(self.training_corpus))
tmp.flush()
spm.SentencePieceTrainer.train(
input=tmp.name,
vocab_size=self.subword_vocab_size,
model_prefix=self.prefix,
user_defined_symbols=self.special_tokens,
)
# Load the model
self.sentencepiece_model = spm.SentencePieceProcessor(
model_file=str(self.prefix.with_suffix(".model"))
)
def subword_tokenize(self, text: str) -> str:
"""
Tokenize into subwords. Sampling is disabled to ensure reproducibility.
"""
tokens = self.sentencepiece_model.encode(text, out_type=str)
return " ".join(map("".join, map(self.encode, tokens)))
def word_tokenize(self, text: str) -> str:
"""
Tokenize text into a string of space-separated words. Spaces (⎵) and NER tokens are considered as words.
:param text: Text to be tokenized.
"""
words = list(map("".join, map(self.encode, wordpunct_tokenize(text))))
return " ".join(
[
word + f" {self.mapping.space.encoded}"
if (i != len(words) - 1 and word not in self.ner_tokens)
else word
for i, word in enumerate(words)
]
)
def char_tokenize(self, text: str) -> str:
"""
Tokenize text into a string of space-separated characters.
:param text: Text to be tokenized.
"""
return " ".join(
[
char if char in self.charset else self.unknown_token
for char in self.encode(text)
]
)
def encode(self, text: List[str]) -> List[str]:
"""
Encode special tokens.
:param text: Text to be encoded.
"""
return map(self.mapping.encode_token, text)
......@@ -22,7 +22,7 @@ class Token(NamedTuple):
class LMTokenMapping(NamedTuple):
space: Token = Token("", " ")
space: Token = Token("", " ")
linebreak: Token = Token("", "\n")
ctc: Token = Token("", "<ctc>")
......@@ -139,7 +139,9 @@ def parse_tokens(filename: str) -> Dict[str, EntityType]:
def read_yaml(yaml_path: str) -> Dict:
"""
Read YAML tokens file
Read YAML tokens file.
:param yaml_path: Path of the YAML file to read.
:return: The content of the read file.
"""
filename = Path(yaml_path)
assert filename.exists(), f"{yaml_path} does not resolve."
......@@ -152,6 +154,8 @@ def read_yaml(yaml_path: str) -> Dict:
def read_json(json_path: str) -> Dict:
"""
Read labels JSON file
:param json_path: Path of the JSON file to read.
:return: The content of the read file.
"""
filename = Path(json_path)
assert filename.exists(), f"{json_path} does not resolve."
......
......@@ -4,13 +4,15 @@ There are a several steps to follow when training a DAN model.
## 1. Extract data
The data must be extracted and formatted for training. To extract the data, DAN uses an Arkindex export database in SQLite format. You will need to:
To extract the data, DAN uses an Arkindex export database in SQLite format. You will need to:
1. Structure the data into folders (`train` / `val` / `test`) in [Arkindex](https://demo.arkindex.org/).
1. [Export the project](https://doc.arkindex.org/howto/export/) in SQLite format.
1. Extract the data with the [extract command](../usage/datasets/extract.md).
At the end, you should have a tree structure like this:
This command will extract and format the images and labels needed to train DAN. It will also tokenize the training corpus at character, subword, and word levels, allowing you to combine DAN with an explicit statistical language model to improve performance.
At the end, you should get the following tree structure:
```
output/
......@@ -20,10 +22,16 @@ output/
│ ├── train
│ ├── val
│ └── test
└── language_model
├── corpus.txt
├── lexicon.txt
└── tokens.txt
├── language_model
│ ├── corpus_characters.txt
│ ├── lexicon_characters.txt
│ ├── corpus_subwords.txt
│ ├── lexicon_subwords.txt
│ ├── corpus_words.txt
│ ├── lexicon_words.txt
│ ├── subword_tokenizer.model
│ ├── subword_tokenizer.vocab
│ └── tokens.txt
```
## 2. Train
......
......@@ -7,7 +7,7 @@ Use the `teklia-dan dataset extract` command to extract a dataset from an Arkind
- Generate the images of each element (in the `images/` folder),
- Create the mapping of the images (identified by its path) to the ground-truth transcription (with NER tokens if needed) (in the `labels.json` file),
- Store the set of characters encountered in the dataset (in the `charset.pkl` file),
- Generate the resources needed to build a N-gram language model with [kenlm](https://github.com/kpu/kenlm) (in the `language_model/` folder).
- Generate the resources needed to build a n-gram language model at character, subword or word-level with [kenlm](https://github.com/kpu/kenlm) (in the `language_model/` folder).
If an image download fails for whatever reason, it won't appear in the transcriptions file. The reason will be printed to stdout at the end of the process. Before trying to download the image, it checks that it wasn't downloaded previously. It is thus safe to run this command twice if a few images failed.
......@@ -30,6 +30,7 @@ If an image download fails for whatever reason, it won't appear in the transcrip
| `--keep-spaces` | Transcriptions are trimmed by default. Use this flag to disable this behaviour. | `bool` | `False` |
| `--image-format` | Images will be saved under this format. | `str` | `.jpg` |
| `--allow-empty` | Elements with no transcriptions are skipped by default. This flag disables this behaviour. | `bool` | `False` |
| `--subword-vocab-size` | Size of the vocabulary used to train the sentencepiece subword tokenizer used to train the optional language model. | `int` | `1000` |
The `--tokens` argument expects a YAML-formatted file with a specific format. A list of entries with each entry describing a NER entity. The label of the entity is the key to a dict mapping the starting and ending tokens respectively. This file can be generated by the `teklia-dan dataset tokens` command. More details in the [dedicated page](./tokens.md).
......
......@@ -174,38 +174,123 @@ It will create the following JSON file named `predict/example.json` and a GIF sh
This example assumes that you have already [trained a language model](../train/language_model.md).
First, update the `inference_parameters.yml` file obtained during DAN training. The `weight` parameter defines how much weight to give to the language model. It should be set carefully (usually between 0.5 and 2.0) as it will affect the quality of the predictions.
Note that:
- the `weight` parameter defines how much weight to give to the language model. It should be set carefully (usually between 0.5 and 2.0) as it will affect the quality of the predictions.
- linebreaks are treated as spaces by language models, as a result predictions will not include linebreaks.
#### Language model at character level
First, update the `inference_parameters.yml` file obtained during DAN training.
```yaml
parameters:
...
language_model:
model: language_model/model.arpa
lexicon: language_model/lexicon.txt
tokens: language_model/tokens.txt
model: my_dataset/language_model/model_characters.arpa
lexicon: my_dataset/language_model/lexicon_characters.txt
tokens: my_dataset/language_model/tokens.txt
weight: 0.5
```
Note that the `weight` parameter defines how much weight to give to the language model. It should be set carefully (usually between 0.5 and 2.0) as it will affect the quality of the predictions.
Then, run this command:
```shell
teklia-dan predict \
--image example.jpg \
--model model.pt \
--parameters inference_parameters.yml \
--charset charset.pkl \
--image dan_humu_page/6e830f23-e70d-4399-8b94-f36ed3198575.jpg \
--model dan_humu_page/model.pt \
--parameters dan_humu_page/inference_parameters_char_lm.yml \
--charset dan_humu_page/charset.pkl \
--use-language-model \
--output predict/
--output dan_humu_page/predict_char_lm/
```
It will create the following JSON file named `predict/example.json`
It will create the following JSON file named `dan_humu_page/predict_char_lm/6e830f23-e70d-4399-8b94-f36ed3198575.json`
```json
{
"text": "etc., some jeg netop idag\nholder Vask paa.\nLeien af Skj\u00f8rterne\nbestad i at jeg kj\u00f8bte\net Forkl\u00e6de til hver\naf de to Piger, some\nhavde laant os dem.\nResten var Vask af Hardan-\ngerskj\u00f8rter og et Forkl\u00e6de,\nsamt Fragt paa det Gods\n(N\u00f8i) some man sendte\nmig ubet\u00e6lt.\nIdag fik jeg hyggeligt\nFrimarkebrev fra Fosvold\nMed Hilsen\nDeres\nHulda Garborg",
"language_model": {
"text": "eet., some jeg netop idag holder Vask paa. Leien af Skj\u00f8rterne bestad i at jeg kj\u00f8bte et Forkl\u00e6de til hver af de to Piger, some havde laant os dem. Resten var Vask af Hardan- gerskj\u00f8rter og et Forkl\u00e6de, samt Fragt paa det Gods (T\u00f8i) some man sendte mig ubet\u00e6lt. Idag fik jeg hyggeligt Frimarkebrev fra Fosvold Med Hilsen Deres Hulda Garborg",
"confidence": 0.9
}
}
```
#### Language model at subword level
Update the `inference_parameters.yml` file obtained during DAN training.
```yaml
parameters:
...
language_model:
model: my_dataset/language_model/model_subwords.arpa
lexicon: my_dataset/language_model/lexicon_subwords.txt
tokens: my_dataset/language_model/tokens.txt
weight: 0.5
```
Then, run this command:
```shell
teklia-dan predict \
--image dan_humu_page/6e830f23-e70d-4399-8b94-f36ed3198575.jpg \
--model dan_humu_page/model.pt \
--parameters dan_humu_page/inference_parameters_subword_lm.yml \
--charset dan_humu_page/charset.pkl \
--use-language-model \
--output dan_humu_page/predict_subword_lm/
```
It will create the following JSON file named `dan_humu_page/predict_subword_lm/6e830f23-e70d-4399-8b94-f36ed3198575.json`
```json
{
"text": "etc., some jeg netop idag\nholder Vask paa.\nLeien af Skj\u00f8rterne\nbestad i at jeg kj\u00f8bte\net Forkl\u00e6de til hver\naf de to Piger, some\nhavde laant os dem.\nResten var Vask af Hardan-\ngerskj\u00f8rter og et Forkl\u00e6de,\nsamt Fragt paa det Gods\n(N\u00f8i) some man sendte\nmig ubet\u00e6lt.\nIdag fik jeg hyggeligt\nFrimarkebrev fra Fosvold\nMed Hilsen\nDeres\nHulda Garborg",
"language_model": {
"text": "eet., some jeg netop idag holder Vask paa. Leien af Skj\u00f8rterne bestad i at jeg kj\u00f8bte et Forkl\u00e6de til hver af de to Piger, some havde laant os dem. Resten var Vask af Hardan- gerskj\u00f8rter og et Forkl\u00e6de, samt Fragt paa det Gods (T\u00f8i) some man sendte mig ubet\u00e6lt. Idag fik jeg hyggeligt Frim\u00e6rkebrev fra Fosvold Med Hilsen Deres Hulda Garborg",
"confidence": 0.84
}
}
```
#### Language model at word level
Update the `inference_parameters.yml` file obtained during DAN training.
```yaml
parameters:
...
language_model:
model: my_dataset/language_model/model_words.arpa
lexicon: my_dataset/language_model/lexicon_words.txt
tokens: my_dataset/language_model/tokens.txt
weight: 0.5
```
Then, run this command:
```shell
teklia-dan predict \
--image dan_humu_page/6e830f23-e70d-4399-8b94-f36ed3198575.jpg \
--model dan_humu_page/model.pt \
--parameters dan_humu_page/inference_parameters_word_lm.yml \
--charset dan_humu_page/charset.pkl \
--use-language-model \
--output dan_humu_page/predict_word_lm/
```
It will create the following JSON file named `dan_humu_page/predict_word_lm/6e830f23-e70d-4399-8b94-f36ed3198575.json`
```json
{
"text": "etc., some jeg netop idag\nholder Vask paa.\nLeien af Skj\u00f8rterne\nbestad i at jeg kj\u00f8bte\net Forkl\u00e6de til hver\naf de to Piger, some\nhavde laant os dem.\nResten var Vask af Hardan-\ngerskj\u00f8rter og et Forkl\u00e6de,\nsamt Fragt paa det Gods\n(N\u00f8i) some man sendte\nmig ubet\u00e6lt.\nIdag fik jeg hyggeligt\nFrimarkebrev fra Fosvold\nMed Hilsen\nDeres\nHulda Garborg",
"language_model": {
"text": "eet., some jeg netop idag\nholder Vask paa.\nLeien af Skj\u00f9rterne\nbestad i at jeg kj\u00f9bte\net Forkl\u00e7de til hver\naf de to Piger, some\nhavde laant os dem.\nResten var Vask af Hardan-\ngerskj\u00f9rter og et Forkl\u00e7de,\nsamt Fragt paa det Gods\n(N\u00f9i) some man sendte\nmig ubetalt.\nIdag fik jeg hyggeligt\nFrimarkebrev fra Fosvold\nMed Hilsen\nDeres\nHulda Garborg",
"confidence": 0.87
"text": "etc., some jeg netop idag holder Vask paa. Leien af Skj\u00f8rterne bestad i at jeg kj\u00f8bte et Forkl\u00e6de til hver af de to Piger, some havde laant os dem. Resten var Vask af Hardan- gerskj\u00f8rter og et Forkl\u00e6de, samt Fragt paa det Gods (T\u00f8i) some man sendte mig ubetalt. Idag fik jeg hyggeligt Frim\u00e6rkebrev fra Fosvold Med Hilsen Deres Hulda Garborg",
"confidence": 0.77
}
}
```
......@@ -9,14 +9,18 @@ To build the language model, you first need to install and compile [kenlm](https
## Build the language model
The `teklia-dan dataset extract` automatically generate the files required to train the language model in the `language_model/` folder.
The `teklia-dan dataset extract` automatically generate the files required to train a language model either at character, subword or word-level in `my_dataset/language_model/`.
Use the following command to build a 6-gram language model:
Note that linebreaks are replaced by spaces in the language model.
### Character-level
At character-level, we recommend building a 6-gram model. Use the following command:
```sh
bin/lmplz --order 6 \
--text language_model/corpus.txt \
--arpa language_model/model.arpa
--text my_dataset/language_model/corpus_characters.txt \
--arpa my_dataset/language_model/model_characters.arpa
```
The following message should be displayed if the language model was built successfully:
......@@ -57,3 +61,27 @@ Chain sizes: 1:1308 2:27744 3:159140 4:412536 5:717920 6:1028896
****************************************************************************************************
Name:lmplz VmPeak:12643224 kB VmRSS:6344 kB RSSMax:1969316 kB user:0.196445 sys:0.514686 CPU:0.711161 real:0.682693
```
### Subord-level
At subword-level, we recommend building a 6-gram model. Use the following command:
```sh
bin/lmplz --order 6 \
--text my_dataset/language_model/corpus_subwords.txt \
--arpa my_dataset/language_model/model_subwords.arpa
```
### Word-level
At word-level, we recommend building a 3-gram model. Use the following command:
```sh
bin/lmplz --order 3 \
--text my_dataset/language_model/corpus_words.txt \
--arpa my_dataset/language_model/model_words.arpa
```
## Predict with a language model
See the [dedicated example](../predict/index.md#predict-with-an-external-n-gram-language-model).
......@@ -6,11 +6,13 @@ flashlight-text==0.0.4
imageio==2.26.1
imagesize==1.4.1
mdutils==1.6.0
nltk==3.8.1
numpy==1.24.3
prettytable==3.8.0
PyYAML==6.0
scipy==1.10.1
teklia-line-image-extractor==0.2.8rc4
sentencepiece==0.1.99
teklia-line-image-extractor==0.2.8rc5
tenacity==8.2.3
tensorboard==2.12.2
torch==2.0.0
......
⎵ ⎵
▁ ▁
! !
" "
& &
......
This diff is collapsed.
!
"
&
......
......@@ -33,7 +33,7 @@ EXTRACTION_DATA_PATH = FIXTURES / "extraction"
TWO_SPACES_REGEX = re.compile(r" {2}")
ENTITY_TOKEN_SPACE = re.compile(r"[ⓢ|ⓕ|ⓑ] ")
TWO_SPACES_LM_REGEX = re.compile(r"⎵ ⎵")
TWO_SPACES_LM_REGEX = re.compile(r"▁ ▁")
# NamedTuple to mock actual database result
Entity = NamedTuple("Entity", offset=int, length=int, type=str, value=str)
......@@ -319,11 +319,127 @@ def test_process_element_unknown_token_in_text_error(mock_database, tmp_path):
arkindex_extractor.process_element(element, "val")
@pytest.mark.parametrize("load_entities", (True, False))
@pytest.mark.parametrize("keep_spaces", (True, False))
# Transcription and entities have the same worker version
@pytest.mark.parametrize(
"transcription_entities_worker_version", ("worker_version_id", False)
"load_entities,keep_spaces,transcription_entities_worker_version,expected_subword_language_corpus,subword_vocab_size",
(
(
True,
True,
"worker_version_id",
"""▁ ⓢ c a i l l e t ▁ ⓕ m a u r i c e ▁ ⓑ 28. 9.0 6
▁ ⓢ re b ou l ▁ ⓕ j e a n ▁ ⓑ 30. 9.0 2
▁ ⓢ b a re y re ▁ ⓕ j e a n ▁ ⓑ 28. 3 . 1 1
▁ ⓢ r ou s s y ▁ ⓕ j e a n ▁ ⓑ 4 . 1 1 . 1 4
▁ ⓢ m a r i n ▁ ⓕ m a r c e l ▁ ⓑ 1 0 . 8 . 0 6
▁ ⓢ a m i c a l ▁ ⓕ e l o i ▁ ⓑ 1 1 . 1 0 . 0 4
▁ ⓢ b i r o s ▁ ⓕ m a e l ▁ ⓑ 30. 1 0 . 1 0""",
40,
),
(
True,
False,
"worker_version_id",
"""▁ ⓢ c a i l l e t ▁ ⓕ m a u r i c e ▁ ⓑ 28. 9.0 6
▁ ⓢ re b ou l ▁ ⓕ j e a n ▁ ⓑ 30. 9.0 2
▁ ⓢ b a re y re ▁ ⓕ j e a n ▁ ⓑ 28. 3 . 1 1
▁ ⓢ r ou s s y ▁ ⓕ j e a n ▁ ⓑ 4 . 1 1 . 1 4
▁ ⓢ m a r i n ▁ ⓕ m a r c e l ▁ ⓑ 1 0 . 8 . 0 6
▁ ⓢ a m i c a l ▁ ⓕ e l o i ▁ ⓑ 1 1 . 1 0 . 0 4
▁ ⓢ b i r o s ▁ ⓕ m a e l ▁ ⓑ 30. 1 0 . 1 0""",
40,
),
(
False,
True,
"worker_version_id",
"""▁ ca i l l e t ▁ ma u r i ce ▁ 28. 9.0 6
▁ re b o u l ▁ j e a n ▁ 30. 9.0 2
▁ b a re y re ▁ j e a n ▁ 28. 3 . 1 1
▁ r o u s s y ▁ j e a n ▁ 4 . 11.1 4
▁ ma r i n ▁ ma r ce l ▁ 10. 8 . 0 6
▁ a m i ca l ▁ el o i ▁ 11.1 0 . 0 4
▁ b i r o s ▁ ma el ▁ 30. 10. 1 0""",
40,
),
(
False,
False,
"worker_version_id",
"""▁ ca i l l e t ▁ ma u r i ce ▁ 28. 9.0 6
▁ re b o u l ▁ j e a n ▁ 30. 9.0 2
▁ b a re y re ▁ j e a n ▁ 28. 3 . 1 1
▁ r o u s s y ▁ j e a n ▁ 4 . 11.1 4
▁ ma r i n ▁ ma r ce l ▁ 10. 8 . 0 6
▁ a m i ca l ▁ el o i ▁ 11.1 0 . 0 4
▁ b i r o s ▁ ma el ▁ 30. 10. 1 0""",
40,
),
(
True,
True,
False,
"""▁ ⓢ C a i l l e t ▁ ⓕ M a u r i c e ▁ ⓑ 2 8 . 9 . 0 6
▁ ⓢ R e b o u l ▁ ⓕ J e a n ▁ ⓑ 3 0 . 9 . 0 2
▁ ⓢ B a r e y r e ▁ ⓕ J e a n ▁ ⓑ 2 8 . 3 . 1 1
▁ ⓢ R o u s s y ▁ ⓕ J e a n ▁ ⓑ 4 . 1 1 . 1 4
▁ ⓢ M a r i n ▁ ⓕ M a r c e l ▁ ⓑ 1 0 . 8 . 0 6
▁ ⓢ A m i c a l ▁ ⓕ E l o i ▁ ⓑ 1 1 . 1 0 . 0 4
▁ ⓢ B i r o s ▁ ⓕ M a e l ▁ ⓑ 3 0 . 1 0 . 1 0""",
40,
),
(
True,
True,
False,
"""▁ ⓢ C a i l l e t ▁ ⓕ M a u ri ce ▁ ⓑ 28. 9.0 6
▁ ⓢ R e b ou l ▁ ⓕ J e a n ▁ ⓑ 30. 9.0 2
▁ ⓢ B a re y re ▁ ⓕ J e a n ▁ ⓑ 28. 3 . 1 1
▁ ⓢ R ou s s y ▁ ⓕ J e a n ▁ ⓑ 4 . 11.1 4
▁ ⓢ Mar i n ▁ ⓕ Mar ce l ▁ ⓑ 10. 8 . 0 6
▁ ⓢ A m ic a l ▁ ⓕ E l o i ▁ ⓑ 11.1 0 . 0 4
▁ ⓢ B i r o s ▁ ⓕ M a e l ▁ ⓑ 30. 10. 10""",
55,
),
(
True,
False,
False,
"""▁ ⓢ C a i l l e t ▁ ⓕ M a u r i c e ▁ ⓑ 2 8 . 9 . 0 6
▁ ⓢ R e b o u l ▁ ⓕ J e a n ▁ ⓑ 3 0 . 9 . 0 2
▁ ⓢ B a r e y r e ▁ ⓕ J e a n ▁ ⓑ 2 8 . 3 . 1 1
▁ ⓢ R o u s s y ▁ ⓕ J e a n ▁ ⓑ 4 . 1 1 . 1 4
▁ ⓢ M a r i n ▁ ⓕ M a r c e l ▁ ⓑ 1 0 . 8 . 0 6
▁ ⓢ A m i c a l ▁ ⓕ E l o i ▁ ⓑ 1 1 . 1 0 . 0 4
▁ ⓢ B i r o s ▁ ⓕ M a e l ▁ ⓑ 3 0 . 1 0 . 1 0""",
40,
),
(
False,
True,
False,
"""▁ C a i l l e t ▁ Ma u r i c e ▁ 28. 9.0 6
▁ R e b o u l ▁ J e a n ▁ 30. 9.0 2
▁ B a r e y r e ▁ J e a n ▁ 28. 3 . 1 1
▁ R o u s s y ▁ J e a n ▁ 4 . 1 1 . 1 4
▁ Ma r i n ▁ Ma r c e l ▁ 1 0 . 8 . 0 6
▁ A m i c a l ▁ E l o i ▁ 1 1 . 1 0 . 0 4
▁ B i r o s ▁ Ma e l ▁ 30. 1 0 . 1 0""",
40,
),
(
False,
False,
False,
"""▁ C a i l l e t ▁ Ma u r i c e ▁ 28. 9.0 6
▁ R e b o u l ▁ J e a n ▁ 30. 9.0 2
▁ B a r e y r e ▁ J e a n ▁ 28. 3 . 1 1
▁ R o u s s y ▁ J e a n ▁ 4 . 1 1 . 1 4
▁ Ma r i n ▁ Ma r c e l ▁ 1 0 . 8 . 0 6
▁ A m i c a l ▁ E l o i ▁ 1 1 . 1 0 . 0 4
▁ B i r o s ▁ Ma e l ▁ 30. 1 0 . 1 0""",
40,
),
),
)
@patch("dan.datasets.extract.arkindex.download_image")
def test_extract(
......@@ -332,6 +448,8 @@ def test_extract(
keep_spaces,
transcription_entities_worker_version,
mock_database,
expected_subword_language_corpus,
subword_vocab_size,
tmp_path,
):
output = tmp_path / "extraction"
......@@ -362,6 +480,7 @@ def test_extract(
else None,
keep_spaces=keep_spaces,
image_extension=".jpg",
subword_vocab_size=subword_vocab_size,
)
# Mock build_image_url to simply return the path to the image
extractor.build_iiif_url = mock_build_image_url
......@@ -398,8 +517,14 @@ def test_extract(
VAL_DIR / "val-page_1-line_3.jpg",
output / "labels.json",
# Language resources
output / "language_model" / "corpus.txt",
output / "language_model" / "lexicon.txt",
output / "language_model" / "corpus_characters.txt",
output / "language_model" / "corpus_subwords.txt",
output / "language_model" / "corpus_words.txt",
output / "language_model" / "lexicon_characters.txt",
output / "language_model" / "lexicon_subwords.txt",
output / "language_model" / "lexicon_words.txt",
output / "language_model" / "subword_tokenizer.model",
output / "language_model" / "subword_tokenizer.vocab",
output / "language_model" / "tokens.txt",
]
assert sorted(filter(methodcaller("is_file"), output.rglob("*"))) == expected_paths
......@@ -466,36 +591,67 @@ def test_extract(
assert set(pickle.loads((output / "charset.pkl").read_bytes())) == expected_charset
# Check "language_corpus.txt"
expected_language_corpus = """ⓢ C a i l l e t ⎵ ⎵ ⓕ M a u r i c e ⎵ ⎵ ⓑ 2 8 . 9 . 0 6
ⓢ R e b o u l ⎵ ⎵ ⓕ J e a n ⎵ ⎵ ⓑ 3 0 . 9 . 0 2
ⓢ B a r e y r e ⎵ ⎵ ⓕ J e a n ⎵ ⎵ ⓑ 2 8 . 3 . 1 1
ⓢ R o u s s y ⎵ ⎵ ⓕ J e a n ⎵ ⎵ ⓑ 4 . 1 1 . 1 4
ⓢ M a r i n ⎵ ⎵ ⓕ M a r c e l ⎵ ⎵ ⓑ 1 0 . 8 . 0 6
ⓢ A m i c a l ⎵ ⎵ ⓕ E l o i ⎵ ⎵ ⓑ 1 1 . 1 0 . 0 4
ⓢ B i r o s ⎵ ⎵ ⓕ M a e l ⎵ ⎵ ⓑ 3 0 . 1 0 . 1 0"""
expected_char_language_corpus = """ⓢ C a i l l e t ▁ ▁ ⓕ M a u r i c e ▁ ▁ ⓑ 2 8 . 9 . 0 6
ⓢ R e b o u l ▁ ▁ ⓕ J e a n ▁ ▁ ⓑ 3 0 . 9 . 0 2
ⓢ B a r e y r e ▁ ▁ ⓕ J e a n ▁ ▁ ⓑ 2 8 . 3 . 1 1
ⓢ R o u s s y ▁ ▁ ⓕ J e a n ▁ ▁ ⓑ 4 . 1 1 . 1 4
ⓢ M a r i n ▁ ▁ ⓕ M a r c e l ▁ ▁ ⓑ 1 0 . 8 . 0 6
ⓢ A m i c a l ▁ ▁ ⓕ E l o i ▁ ▁ ⓑ 1 1 . 1 0 . 0 4
ⓢ B i r o s ▁ ▁ ⓕ M a e l ▁ ▁ ⓑ 3 0 . 1 0 . 1 0"""
expected_word_language_corpus = """ⓢ Caillet ▁ ⓕ Maurice ▁ ⓑ 28 ▁ . ▁ 9 ▁ . ▁ 06
ⓢ Reboul ▁ ⓕ Jean ▁ ⓑ 30 ▁ . ▁ 9 ▁ . ▁ 02
ⓢ Bareyre ▁ ⓕ Jean ▁ ⓑ 28 ▁ . ▁ 3 ▁ . ▁ 11
ⓢ Roussy ▁ ⓕ Jean ▁ ⓑ 4 ▁ . ▁ 11 ▁ . ▁ 14
ⓢ Marin ▁ ⓕ Marcel ▁ ⓑ 10 ▁ . ▁ 8 ▁ . ▁ 06
ⓢ Amical ▁ ⓕ Eloi ▁ ⓑ 11 ▁ . ▁ 10 ▁ . ▁ 04
ⓢ Biros ▁ ⓕ Mael ▁ ⓑ 30 ▁ . ▁ 10 ▁ . ▁ 10"""
# Transcriptions with worker version are in lowercase
if transcription_entities_worker_version:
expected_language_corpus = expected_language_corpus.lower()
expected_char_language_corpus = expected_char_language_corpus.lower()
expected_word_language_corpus = expected_word_language_corpus.lower()
expected_subword_language_corpus = expected_subword_language_corpus.lower()
# If we do not load entities, remove tokens
if not load_entities:
token_translations = {f"{token} ": "" for token in tokens}
expected_language_corpus = ENTITY_TOKEN_SPACE.sub("", expected_language_corpus)
expected_char_language_corpus = ENTITY_TOKEN_SPACE.sub(
"", expected_char_language_corpus
)
expected_word_language_corpus = ENTITY_TOKEN_SPACE.sub(
"", expected_word_language_corpus
)
expected_subword_language_corpus = ENTITY_TOKEN_SPACE.sub(
"", expected_subword_language_corpus
)
# Replace double spaces with regular space
if not keep_spaces:
expected_language_corpus = TWO_SPACES_LM_REGEX.sub(
"", expected_language_corpus
expected_char_language_corpus = TWO_SPACES_LM_REGEX.sub(
"", expected_char_language_corpus
)
expected_word_language_corpus = TWO_SPACES_LM_REGEX.sub(
"", expected_word_language_corpus
)
expected_subword_language_corpus = TWO_SPACES_LM_REGEX.sub(
"", expected_subword_language_corpus
)
assert (
output / "language_model" / "corpus_characters.txt"
).read_text() == expected_char_language_corpus
assert (
output / "language_model" / "corpus_words.txt"
).read_text() == expected_word_language_corpus
assert (
output / "language_model" / "corpus.txt"
).read_text() == expected_language_corpus
output / "language_model" / "corpus_subwords.txt"
).read_text() == expected_subword_language_corpus
# Check "language_tokens.txt"
expected_language_tokens = [
t if t != " " else "" for t in sorted(list(expected_charset))
"" if t.isspace() else t for t in sorted(list(expected_charset))
]
expected_language_tokens.append("")
assert (output / "language_model" / "tokens.txt").read_text() == "\n".join(
......@@ -503,11 +659,29 @@ def test_extract(
)
# Check "language_lexicon.txt"
expected_language_lexicon = [f"{t} {t}" for t in expected_language_tokens]
assert (output / "language_model" / "lexicon.txt").read_text() == "\n".join(
expected_language_lexicon
expected_language_char_lexicon = [f"{t} {t}" for t in expected_language_tokens]
assert (
output / "language_model" / "lexicon_characters.txt"
).read_text() == "\n".join(expected_language_char_lexicon)
word_vocab = set([word for word in expected_word_language_corpus.split()])
expected_language_word_lexicon = [
f"{word} {' '.join(word)}" for word in sorted(word_vocab)
]
assert (output / "language_model" / "lexicon_words.txt").read_text() == "\n".join(
expected_language_word_lexicon
)
subword_vocab = set(
[subword for subword in expected_subword_language_corpus.split()]
)
expected_language_subword_lexicon = [
f"{subword} {' '.join(subword)}" for subword in sorted(subword_vocab)
]
assert (
output / "language_model" / "lexicon_subwords.txt"
).read_text() == "\n".join(expected_language_subword_lexicon)
# Check cropped images
for expected_path in expected_paths:
if expected_path.suffix != ".jpg":
......