Showing with 471 additions and 118470 deletions
# Training configuration
To train a model, you need to write a JSON configuration file. The list of fields are described in the [next section](#dataset-parameters)
To train a model, you need to write a JSON configuration file. The list of fields are described in the [next section](#dataset-parameters).
An empty configuration file is available at `configs/quickstart.json`. You will need to fill in the paths.
## Dataset parameters
......@@ -8,7 +8,9 @@ An empty configuration file is available at `configs/quickstart.json`. You will
| Parameter | Description | Type | Default |
| ----------------------------- | --------------------------------------------------------------------------------------------------------------------- | -------------- | ------- |
| `dataset.max_char_prediction` | Maximum number of characters to predict. | `int` | `1000` |
| `dataset.tokens` | Path to a NER tokens configuration file similar to [the one used for extraction](../datasets/extract.md#description). | `pathlib.Path` | None |
| `dataset.tokens` | Path to a NER tokens configuration file similar to [the one used for extraction](../datasets/extract.md#description). | `pathlib.Path` | |
To determine the value to use for `dataset.max_char_prediction`, you can use the [analyze command](../datasets/analyze.md) to find the maximum number of characters in a label of the dataset.
!!! note
You must replace the pseudo-variables `$dataset_name` and `$dataset_path` with the name and the relative/absolute path to your dataset, respectively.
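For illustration, the two dataset fields above could be filled into the quickstart template as in the sketch below. The nesting of the dotted names under a `dataset` object is an assumption, so check `configs/quickstart.json` for the exact structure; the paths are placeholders.

```py
import json
from pathlib import Path

# Sketch only: assumes the dotted parameter names ("dataset.max_char_prediction")
# nest under a "dataset" object in the JSON file; the paths are placeholders.
config = json.loads(Path("configs/quickstart.json").read_text())
config.setdefault("dataset", {})
config["dataset"]["max_char_prediction"] = 1000  # default from the table above
config["dataset"]["tokens"] = "my_dataset/tokens.yml"  # NER tokens file (optional)
Path("configs/my_config.json").write_text(json.dumps(config, indent=2))
```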
......@@ -27,7 +29,7 @@ An empty configuration file is available at `configs/quickstart.json`. You will
| `model.decoder.l_max` | Maximum predicted sequence length (for 1D positional embedding). | `int` | `15000` |
| `model.decoder.dec_num_layers` | Number of transformer decoder layers. | `int` | `8` |
| `model.decoder.dec_num_heads` | Number of heads in transformer decoder layers. | `int` | `4` |
| `model.decoder.dec_res_dropout` | Dropout probability in transformer decoder layers. | `int` | `0.1` |
| `model.decoder.dec_res_dropout` | Dropout probability in transformer decoder layers. | `float` | `0.1` |
| `model.decoder.dec_pred_dropout` | Dropout rate before decision layer. | `float` | `0.1` |
| `model.decoder.dec_att_dropout` | Dropout rate in multi head attention. | `float` | `0.1` |
| `model.decoder.dec_dim_feedforward` | Number of dimensions for feedforward layer in transformer decoder layers. | `int` | `256` |
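As a reading aid, the decoder defaults listed above can be written out as in the sketch below, assuming the dotted names nest under `model` and `decoder` objects in the JSON file.

```py
# Sketch only: decoder defaults copied from the table above, assuming they nest
# under "model" -> "decoder" in the JSON configuration file.
decoder_defaults = {
    "l_max": 15000,              # maximum predicted sequence length
    "dec_num_layers": 8,         # transformer decoder layers
    "dec_num_heads": 4,          # attention heads per decoder layer
    "dec_res_dropout": 0.1,      # dropout in decoder layers
    "dec_pred_dropout": 0.1,     # dropout before the decision layer
    "dec_att_dropout": 0.1,      # dropout in multi-head attention
    "dec_dim_feedforward": 256,  # feedforward dimension
}
```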
......@@ -48,11 +50,11 @@ An empty configuration file is available at `configs/quickstart.json`. You will
| `training.device.use_ddp` | Whether to use DistributedDataParallel. | `bool` | `False` |
| `training.device.ddp_port` | DDP port. | `int` | `20027` |
| `training.device.use_amp` | Whether to enable automatic mixed precision. | `bool` | `True` |
| `training.device.nb_gpu` | Number of GPUs to train DAN. Set to `null` to use all GPUs available. | `int` | `None` |
| `training.device.nb_gpu` | Number of GPUs to train DAN. Set to `null` to use all GPUs available. | `int` | |
| `training.device.force_cpu` | Whether to train on CPU (for debugging). | `bool` | `False` |
| `training.optimizers.all.args.lr` | Learning rate for the optimizer. | `float` | `0.0001` |
| `training.optimizers.all.args.amsgrad` | Whether to use AMSGrad optimization. | `bool` | `False` |
| `training.lr_schedulers` | Learning rate schedulers. | custom class | `None` |
| `training.lr_schedulers` | Learning rate schedulers. | custom class | |
| `training.validation.eval_on_valid` | Whether to evaluate and log metrics on the validation set during training. | `bool` | `True` |
| `training.validation.eval_on_valid_interval` | Interval (in epochs) to evaluate during training. | `int` | `5` |
| `training.validation.set_name_focus_metric` | Dataset to focus on to select best weights. | `str` | |
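For orientation, the device, optimizer, scheduler and validation defaults above might be grouped as in the sketch below, under the assumption that the dotted names nest under a `training` object; `None` stands for `null` in the JSON file.

```py
# Sketch only: training defaults from the table above, assuming the dotted names
# nest under a "training" object; None corresponds to null in the JSON file.
training_defaults = {
    "device": {
        "use_ddp": False,    # DistributedDataParallel disabled
        "ddp_port": 20027,
        "use_amp": True,     # automatic mixed precision
        "nb_gpu": None,      # null = use all available GPUs
        "force_cpu": False,  # CPU-only training, for debugging
    },
    "optimizers": {"all": {"args": {"lr": 0.0001, "amsgrad": False}}},
    "lr_schedulers": None,   # learning rate schedulers (custom class)
    "validation": {
        "eval_on_valid": True,
        "eval_on_valid_interval": 5,  # evaluate every 5 epochs
    },
}
```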
......@@ -69,17 +71,7 @@ An empty configuration file is available at `configs/quickstart.json`. You will
### Data preprocessing
Preprocessing is applied before training the network (see `dan/manager/dataset.py`). The list of accepted transforms is defined in `dan/transforms.py`:
```py
from enum import Enum


class Preprocessing(Enum):
# If the image is bigger than the given size, resize it while keeping the original ratio
MaxResize = "max_resize"
# Resize the height to a fixed value while keeping the original ratio
FixedHeightResize = "fixed_height_resize"
# Resize the width to a fixed value while keeping the original ratio
FixedWidthResize = "fixed_width_resize"
```
Preprocessing is applied before training the network (see the [dedicated references](../../ref/ocr/managers/dataset.md)). The list of accepted transforms is defined in the [dedicated references](../../ref/ocr/transforms.md#dan.ocr.transforms.Preprocessing).
Usage:
......@@ -182,9 +174,12 @@ $ pip install .[mlflow]
- update the following arguments:
| Name | Description | Type | Default |
| ------------------------------ | ------------------------------------ | ----- | ------- |
| `mlflow.run_id` | Name of the current run in MLflow. | `str` | |
| `mlflow.s3_endpoint_url` | URL of S3 endpoint. | `str` | |
| `mlflow.aws_access_key_id` | Access key id to the AWS server. | `str` | |
| `mlflow.aws_secret_access_key` | Secret access key to the AWS server. | `str` | |
| Name | Description | Type | Default |
| ------------------------------ | --------------------------------------- | ----- | ------- |
| `mlflow.run_id` | ID of the current run in MLflow. | `int` | |
| `mlflow.run_name` | Name of the current run in MLflow. | `str` | |
| `mlflow.s3_endpoint_url` | URL of S3 endpoint. | `str` | |
| `mlflow.tracking_uri` | URI of a tracking server. | `str` | |
| `mlflow.experiment_id` | ID of the current experiment in MLflow. | `str` | |
| `mlflow.aws_access_key_id` | Access key ID to the AWS server. | `str` | |
| `mlflow.aws_secret_access_key` | Secret access key to the AWS server. | `str` | |
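To make the MLflow block concrete, the sketch below groups the fields above under an `mlflow` key; that layout is an assumption and every value is a placeholder, not a real endpoint or credential.

```py
# Sketch only: every value is a placeholder; the grouping under "mlflow" is an
# assumption based on the dotted parameter names above.
mlflow_config = {
    "mlflow": {
        "run_id": 1,                                   # ID of the current MLflow run
        "run_name": "dan-training",                    # name of the current run
        "s3_endpoint_url": "https://s3.example.com",   # S3 endpoint
        "tracking_uri": "https://mlflow.example.com",  # tracking server
        "experiment_id": "0",                          # current experiment
        "aws_access_key_id": "<access-key-id>",
        "aws_secret_access_key": "<secret-access-key>",
    }
}
```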
......@@ -4,11 +4,12 @@ Use the `teklia-dan train` command to train a new DAN model. It is able to train
To train DAN on your dataset:
1. Create a training JSON configuration file. Refer to the [dedicated section](parameters.md) for a description of parameters.
1. Create a training JSON configuration file. Refer to the [dedicated page](config.md) for a description of parameters.
1. Run `teklia-dan train --config path/to/your/config.json`.
1. Look into evaluation results in the output folder indicated in your configuration:
- `checkpoints` contains model weights for the last trained epoch and for the epoch giving the best valid CER.
- `results` contains the tensorboard log file, the parameters file, and the evaluation results for the best epoch.
1. (Optional) Train a language model. Refer to the [dedicated page](language_model.md).
## Additional pages
......
......@@ -43,7 +43,7 @@ teklia-dan train
## Train on multiple GPUs
To train on multiple GPUs, one needs to update the parameters in the training configuration file, as detailed in the [dedicated section](parameters.md#training-parameters). In addition, the number of GPUs required must be specified in the `train_dan.sh` file by updating the following line:
To train on multiple GPUs, one needs to update the parameters in the training configuration file, as detailed in the [dedicated page](config.md#training-parameters). In addition, the number of GPUs required must be specified in the `train_dan.sh` file by updating the following line:
```sh
#SBATCH --gres=gpu:<nb_gpus> # number of GPUs per node
......
......@@ -9,21 +9,25 @@ To build the language model, you first need to install and compile [kenlm](https
## Build the language model
The `teklia-dan dataset extract` automatically generate the files required to train the language model in `my_dataset/language_model/`.
The `teklia-dan dataset extract` command automatically generates the files required to train a language model at character, subword or word level in `my_dataset/language_model/`.
Use the following command to build a 6-gram language model:
Note that linebreaks are replaced by spaces in the language model.
### Character-level
At character-level, we recommend building a 6-gram model. Use the following command:
```sh
bin/lmplz --order 6 \
--text my_dataset/language_model/corpus.txt \
--arpa my_dataset/language_model/model.arpa
--text my_dataset/language_model/corpus_characters.txt \
--arpa my_dataset/language_model/model_characters.arpa
```
The following message should be displayed if the language model was built successfully.
The following message should be displayed if the language model was built successfully:
```sh
=== 1/5 Counting and sorting n-grams ===
Reading my_dataset/language_model/corpus.txt
Reading language_model/corpus.txt
----5---10---15---20---25---30---35---40---45---50---55---60---65---70---75---80---85---90---95--100
****************************************************************************************************
Unigram tokens 111629 types 109
......@@ -58,6 +62,26 @@ Chain sizes: 1:1308 2:27744 3:159140 4:412536 5:717920 6:1028896
Name:lmplz VmPeak:12643224 kB VmRSS:6344 kB RSSMax:1969316 kB user:0.196445 sys:0.514686 CPU:0.711161 real:0.682693
```
### Subword-level
At subword-level, we recommend building a 6-gram model. Use the following command:
```sh
bin/lmplz --order 6 \
--text my_dataset/language_model/corpus_subwords.txt \
--arpa my_dataset/language_model/model_subwords.arpa
```
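If you want to inspect the subword segmentation before building the n-gram model, the tokenizer written by the extraction step can be loaded with the `sentencepiece` library pinned in the requirements. The sketch below assumes `subword_tokenizer.model` is a SentencePiece model and uses an arbitrary sample line.

```py
# Sketch only: loads the tokenizer produced by `teklia-dan dataset extract` and
# prints the subword pieces for an arbitrary sample line.
import sentencepiece as spm

sp = spm.SentencePieceProcessor(
    model_file="my_dataset/language_model/subword_tokenizer.model"
)
print(sp.encode("caillet maurice 28.9.06", out_type=str))
```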
### Word-level
At word-level, we recommend building a 3-gram model. Use the following command:
```sh
bin/lmplz --order 3 \
--text my_dataset/language_model/corpus_words.txt \
--arpa my_dataset/language_model/model_words.arpa
```
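As an optional sanity check, the resulting ARPA files can be scored from Python if you also build kenlm's Python bindings (they are not required by DAN itself). The sketch below assumes the bindings are installed and uses an arbitrary sample sentence.

```py
# Sketch only: requires kenlm's optional Python bindings (pip install kenlm).
import kenlm

model = kenlm.Model("my_dataset/language_model/model_words.arpa")
print(model.score("caillet maurice 28 . 9 . 06", bos=True, eos=True))
```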
## Predict with a language model
See the [dedicated example](examples.md#predict-with-an-external-n-gram-language-model).
See the [dedicated example](../predict/index.md#predict-with-an-external-n-gram-language-model).
......@@ -61,18 +61,17 @@ nav:
- usage/index.md
- Datasets:
- usage/datasets/index.md
- Dataset extraction: usage/datasets/extract.md
- Dataset analysis: usage/datasets/analyze.md
- Dataset entities: usage/datasets/entities.md
- Dataset tokens: usage/datasets/tokens.md
- Dataset extraction: usage/datasets/extract.md
- Training:
- usage/train/index.md
- Parameters: usage/train/parameters.md
- Configuration: usage/train/config.md
- Data augmentation: usage/train/augmentation.md
- Language model: usage/train/language_model.md
- Jean Zay tutorial: usage/train/jeanzay.md
- Predict:
- usage/predict/index.md
- Train a language model: usage/predict/training_lm.md
- Parameters: usage/predict/parameters.md
- Examples: usage/predict/examples.md
- Predict: usage/predict/index.md
- Python Reference:
- Datasets:
......@@ -81,15 +80,18 @@ nav:
- Analyze:
- ref/datasets/analyze/index.md
- Statistics: ref/datasets/analyze/statistics.md
- Entities:
- ref/datasets/entities/index.md
- Extract: ref/datasets/entities/extract.md
- Extraction:
- ref/datasets/extract/index.md
- Arkindex: ref/datasets/extract/arkindex.md
- Utils: ref/datasets/extract/utils.md
- Database management: ref/datasets/extract/db.md
- Exceptions: ref/datasets/extract/exceptions.md
- Analysis:
- ref/datasets/analyze/index.md
- Statistics: ref/datasets/analyze/statistics.md
- Tokens:
- ref/datasets/tokens/index.md
- Generate: ref/datasets/tokens/generate.md
- OCR:
- ref/ocr/index.md
- Managers:
......@@ -100,13 +102,15 @@ nav:
- Training managers: ref/ocr/managers/training.md
- Training: ref/ocr/train.md
- Prediction:
- Inference: ref/ocr/predict/prediction.md
- ref/ocr/predict/index.md
- Inference: ref/ocr/predict/inference.md
- Attention: ref/ocr/predict/attention.md
- Decoder: ref/ocr/decoder.md
- Encoder: ref/ocr/encoder.md
- MLflow: ref/ocr/mlflow.md
- Schedulers: ref/ocr/schedulers.md
- Transformations: ref/ocr/transforms.md
- CLI: ref/cli.md
- Utils: ref/utils.md
markdown_extensions:
......
......@@ -6,11 +6,13 @@ flashlight-text==0.0.4
imageio==2.26.1
imagesize==1.4.1
mdutils==1.6.0
nltk==3.8.1
numpy==1.24.3
prettytable==3.8.0
PyYAML==6.0
scipy==1.10.1
teklia-line-image-extractor==0.2.8rc4
sentencepiece==0.1.99
teklia-line-image-extractor==0.2.8rc5
tenacity==8.2.3
tensorboard==2.12.2
torch==2.0.0
......
---
entities:
- birthdate
- firstname
- surname
File deleted
⎵ ⎵
▁ ▁
! !
" "
& &
......
This diff is collapsed.
!
"
&
......
---
birthdate:
  start:
  end:
firstname:
  start:
  end:
surname:
  start:
  end:
---
birthdate:
  start:
  end: ''
firstname:
  start:
  end: ''
surname:
  start:
  end: ''
# -*- coding: utf-8 -*-
from dan.datasets.entities.extract import run
from tests import FIXTURES
def test_entities(mock_database, tmp_path):
output_file = tmp_path / "entities.yml"
run(database=mock_database, output_file=output_file)
assert output_file.read_text() == (FIXTURES / "entities.yml").read_text()
......@@ -12,12 +12,12 @@ import pytest
from PIL import Image, ImageChops
from arkindex_export import Element, Transcription
from dan.datasets.extract.arkindex import IIIF_FULL_SIZE, ArkindexExtractor
from dan.datasets.extract.exceptions import (
NoEndTokenError,
NoTranscriptionError,
UnknownTokenInText,
)
from dan.datasets.extract.extract import IIIF_FULL_SIZE, ArkindexExtractor
from dan.datasets.extract.utils import (
EntityType,
download_image,
......@@ -33,7 +33,7 @@ EXTRACTION_DATA_PATH = FIXTURES / "extraction"
TWO_SPACES_REGEX = re.compile(r" {2}")
ENTITY_TOKEN_SPACE = re.compile(r"[ⓢ|ⓕ|ⓑ] ")
TWO_SPACES_LM_REGEX = re.compile(r"⎵ ⎵")
TWO_SPACES_LM_REGEX = re.compile(r"▁ ▁")
# NamedTuple to mock actual database result
Entity = NamedTuple("Entity", offset=int, length=int, type=str, value=str)
......@@ -319,19 +319,137 @@ def test_process_element_unknown_token_in_text_error(mock_database, tmp_path):
arkindex_extractor.process_element(element, "val")
@pytest.mark.parametrize("load_entities", (True, False))
@pytest.mark.parametrize("keep_spaces", (True, False))
# Transcription and entities have the same worker version
@pytest.mark.parametrize(
"transcription_entities_worker_version", ("worker_version_id", False)
"load_entities,keep_spaces,transcription_entities_worker_version,expected_subword_language_corpus,subword_vocab_size",
(
(
True,
True,
"worker_version_id",
"""▁ ⓢ c a i l l e t ▁ ⓕ m a u r i c e ▁ ⓑ 28. 9.0 6
▁ ⓢ re b ou l ▁ ⓕ j e a n ▁ ⓑ 30. 9.0 2
▁ ⓢ b a re y re ▁ ⓕ j e a n ▁ ⓑ 28. 3 . 1 1
▁ ⓢ r ou s s y ▁ ⓕ j e a n ▁ ⓑ 4 . 1 1 . 1 4
▁ ⓢ m a r i n ▁ ⓕ m a r c e l ▁ ⓑ 1 0 . 8 . 0 6
▁ ⓢ a m i c a l ▁ ⓕ e l o i ▁ ⓑ 1 1 . 1 0 . 0 4
▁ ⓢ b i r o s ▁ ⓕ m a e l ▁ ⓑ 30. 1 0 . 1 0""",
40,
),
(
True,
False,
"worker_version_id",
"""▁ ⓢ c a i l l e t ▁ ⓕ m a u r i c e ▁ ⓑ 28. 9.0 6
▁ ⓢ re b ou l ▁ ⓕ j e a n ▁ ⓑ 30. 9.0 2
▁ ⓢ b a re y re ▁ ⓕ j e a n ▁ ⓑ 28. 3 . 1 1
▁ ⓢ r ou s s y ▁ ⓕ j e a n ▁ ⓑ 4 . 1 1 . 1 4
▁ ⓢ m a r i n ▁ ⓕ m a r c e l ▁ ⓑ 1 0 . 8 . 0 6
▁ ⓢ a m i c a l ▁ ⓕ e l o i ▁ ⓑ 1 1 . 1 0 . 0 4
▁ ⓢ b i r o s ▁ ⓕ m a e l ▁ ⓑ 30. 1 0 . 1 0""",
40,
),
(
False,
True,
"worker_version_id",
"""▁ ca i l l e t ▁ ma u r i ce ▁ 28. 9.0 6
▁ re b o u l ▁ j e a n ▁ 30. 9.0 2
▁ b a re y re ▁ j e a n ▁ 28. 3 . 1 1
▁ r o u s s y ▁ j e a n ▁ 4 . 11.1 4
▁ ma r i n ▁ ma r ce l ▁ 10. 8 . 0 6
▁ a m i ca l ▁ el o i ▁ 11.1 0 . 0 4
▁ b i r o s ▁ ma el ▁ 30. 10. 1 0""",
40,
),
(
False,
False,
"worker_version_id",
"""▁ ca i l l e t ▁ ma u r i ce ▁ 28. 9.0 6
▁ re b o u l ▁ j e a n ▁ 30. 9.0 2
▁ b a re y re ▁ j e a n ▁ 28. 3 . 1 1
▁ r o u s s y ▁ j e a n ▁ 4 . 11.1 4
▁ ma r i n ▁ ma r ce l ▁ 10. 8 . 0 6
▁ a m i ca l ▁ el o i ▁ 11.1 0 . 0 4
▁ b i r o s ▁ ma el ▁ 30. 10. 1 0""",
40,
),
(
True,
True,
False,
"""▁ ⓢ C a i l l e t ▁ ⓕ M a u r i c e ▁ ⓑ 2 8 . 9 . 0 6
▁ ⓢ R e b o u l ▁ ⓕ J e a n ▁ ⓑ 3 0 . 9 . 0 2
▁ ⓢ B a r e y r e ▁ ⓕ J e a n ▁ ⓑ 2 8 . 3 . 1 1
▁ ⓢ R o u s s y ▁ ⓕ J e a n ▁ ⓑ 4 . 1 1 . 1 4
▁ ⓢ M a r i n ▁ ⓕ M a r c e l ▁ ⓑ 1 0 . 8 . 0 6
▁ ⓢ A m i c a l ▁ ⓕ E l o i ▁ ⓑ 1 1 . 1 0 . 0 4
▁ ⓢ B i r o s ▁ ⓕ M a e l ▁ ⓑ 3 0 . 1 0 . 1 0""",
40,
),
(
True,
True,
False,
"""▁ ⓢ C a i l l e t ▁ ⓕ M a u ri ce ▁ ⓑ 28. 9.0 6
▁ ⓢ R e b ou l ▁ ⓕ J e a n ▁ ⓑ 30. 9.0 2
▁ ⓢ B a re y re ▁ ⓕ J e a n ▁ ⓑ 28. 3 . 1 1
▁ ⓢ R ou s s y ▁ ⓕ J e a n ▁ ⓑ 4 . 11.1 4
▁ ⓢ Mar i n ▁ ⓕ Mar ce l ▁ ⓑ 10. 8 . 0 6
▁ ⓢ A m ic a l ▁ ⓕ E l o i ▁ ⓑ 11.1 0 . 0 4
▁ ⓢ B i r o s ▁ ⓕ M a e l ▁ ⓑ 30. 10. 10""",
55,
),
(
True,
False,
False,
"""▁ ⓢ C a i l l e t ▁ ⓕ M a u r i c e ▁ ⓑ 2 8 . 9 . 0 6
▁ ⓢ R e b o u l ▁ ⓕ J e a n ▁ ⓑ 3 0 . 9 . 0 2
▁ ⓢ B a r e y r e ▁ ⓕ J e a n ▁ ⓑ 2 8 . 3 . 1 1
▁ ⓢ R o u s s y ▁ ⓕ J e a n ▁ ⓑ 4 . 1 1 . 1 4
▁ ⓢ M a r i n ▁ ⓕ M a r c e l ▁ ⓑ 1 0 . 8 . 0 6
▁ ⓢ A m i c a l ▁ ⓕ E l o i ▁ ⓑ 1 1 . 1 0 . 0 4
▁ ⓢ B i r o s ▁ ⓕ M a e l ▁ ⓑ 3 0 . 1 0 . 1 0""",
40,
),
(
False,
True,
False,
"""▁ C a i l l e t ▁ Ma u r i c e ▁ 28. 9.0 6
▁ R e b o u l ▁ J e a n ▁ 30. 9.0 2
▁ B a r e y r e ▁ J e a n ▁ 28. 3 . 1 1
▁ R o u s s y ▁ J e a n ▁ 4 . 1 1 . 1 4
▁ Ma r i n ▁ Ma r c e l ▁ 1 0 . 8 . 0 6
▁ A m i c a l ▁ E l o i ▁ 1 1 . 1 0 . 0 4
▁ B i r o s ▁ Ma e l ▁ 30. 1 0 . 1 0""",
40,
),
(
False,
False,
False,
"""▁ C a i l l e t ▁ Ma u r i c e ▁ 28. 9.0 6
▁ R e b o u l ▁ J e a n ▁ 30. 9.0 2
▁ B a r e y r e ▁ J e a n ▁ 28. 3 . 1 1
▁ R o u s s y ▁ J e a n ▁ 4 . 1 1 . 1 4
▁ Ma r i n ▁ Ma r c e l ▁ 1 0 . 8 . 0 6
▁ A m i c a l ▁ E l o i ▁ 1 1 . 1 0 . 0 4
▁ B i r o s ▁ Ma e l ▁ 30. 1 0 . 1 0""",
40,
),
),
)
@patch("dan.datasets.extract.extract.download_image")
@patch("dan.datasets.extract.arkindex.download_image")
def test_extract(
mock_download_image,
load_entities,
keep_spaces,
transcription_entities_worker_version,
mock_database,
expected_subword_language_corpus,
subword_vocab_size,
tmp_path,
):
output = tmp_path / "extraction"
......@@ -362,6 +480,7 @@ def test_extract(
else None,
keep_spaces=keep_spaces,
image_extension=".jpg",
subword_vocab_size=subword_vocab_size,
)
# Mock build_image_url to simply return the path to the image
extractor.build_iiif_url = mock_build_image_url
......@@ -398,8 +517,14 @@ def test_extract(
VAL_DIR / "val-page_1-line_3.jpg",
output / "labels.json",
# Language resources
output / "language_model" / "corpus.txt",
output / "language_model" / "lexicon.txt",
output / "language_model" / "corpus_characters.txt",
output / "language_model" / "corpus_subwords.txt",
output / "language_model" / "corpus_words.txt",
output / "language_model" / "lexicon_characters.txt",
output / "language_model" / "lexicon_subwords.txt",
output / "language_model" / "lexicon_words.txt",
output / "language_model" / "subword_tokenizer.model",
output / "language_model" / "subword_tokenizer.vocab",
output / "language_model" / "tokens.txt",
]
assert sorted(filter(methodcaller("is_file"), output.rglob("*"))) == expected_paths
......@@ -466,36 +591,67 @@ def test_extract(
assert set(pickle.loads((output / "charset.pkl").read_bytes())) == expected_charset
# Check "language_corpus.txt"
expected_language_corpus = """ⓢ C a i l l e t ⎵ ⎵ ⓕ M a u r i c e ⎵ ⎵ ⓑ 2 8 . 9 . 0 6
ⓢ R e b o u l ⎵ ⎵ ⓕ J e a n ⎵ ⎵ ⓑ 3 0 . 9 . 0 2
ⓢ B a r e y r e ⎵ ⎵ ⓕ J e a n ⎵ ⎵ ⓑ 2 8 . 3 . 1 1
ⓢ R o u s s y ⎵ ⎵ ⓕ J e a n ⎵ ⎵ ⓑ 4 . 1 1 . 1 4
ⓢ M a r i n ⎵ ⎵ ⓕ M a r c e l ⎵ ⎵ ⓑ 1 0 . 8 . 0 6
ⓢ A m i c a l ⎵ ⎵ ⓕ E l o i ⎵ ⎵ ⓑ 1 1 . 1 0 . 0 4
ⓢ B i r o s ⎵ ⎵ ⓕ M a e l ⎵ ⎵ ⓑ 3 0 . 1 0 . 1 0"""
expected_char_language_corpus = """ⓢ C a i l l e t ▁ ▁ ⓕ M a u r i c e ▁ ▁ ⓑ 2 8 . 9 . 0 6
ⓢ R e b o u l ▁ ▁ ⓕ J e a n ▁ ▁ ⓑ 3 0 . 9 . 0 2
ⓢ B a r e y r e ▁ ▁ ⓕ J e a n ▁ ▁ ⓑ 2 8 . 3 . 1 1
ⓢ R o u s s y ▁ ▁ ⓕ J e a n ▁ ▁ ⓑ 4 . 1 1 . 1 4
ⓢ M a r i n ▁ ▁ ⓕ M a r c e l ▁ ▁ ⓑ 1 0 . 8 . 0 6
ⓢ A m i c a l ▁ ▁ ⓕ E l o i ▁ ▁ ⓑ 1 1 . 1 0 . 0 4
ⓢ B i r o s ▁ ▁ ⓕ M a e l ▁ ▁ ⓑ 3 0 . 1 0 . 1 0"""
expected_word_language_corpus = """ⓢ Caillet ▁ ⓕ Maurice ▁ ⓑ 28 ▁ . ▁ 9 ▁ . ▁ 06
ⓢ Reboul ▁ ⓕ Jean ▁ ⓑ 30 ▁ . ▁ 9 ▁ . ▁ 02
ⓢ Bareyre ▁ ⓕ Jean ▁ ⓑ 28 ▁ . ▁ 3 ▁ . ▁ 11
ⓢ Roussy ▁ ⓕ Jean ▁ ⓑ 4 ▁ . ▁ 11 ▁ . ▁ 14
ⓢ Marin ▁ ⓕ Marcel ▁ ⓑ 10 ▁ . ▁ 8 ▁ . ▁ 06
ⓢ Amical ▁ ⓕ Eloi ▁ ⓑ 11 ▁ . ▁ 10 ▁ . ▁ 04
ⓢ Biros ▁ ⓕ Mael ▁ ⓑ 30 ▁ . ▁ 10 ▁ . ▁ 10"""
# Transcriptions with worker version are in lowercase
if transcription_entities_worker_version:
expected_language_corpus = expected_language_corpus.lower()
expected_char_language_corpus = expected_char_language_corpus.lower()
expected_word_language_corpus = expected_word_language_corpus.lower()
expected_subword_language_corpus = expected_subword_language_corpus.lower()
# If we do not load entities, remove tokens
if not load_entities:
token_translations = {f"{token} ": "" for token in tokens}
expected_language_corpus = ENTITY_TOKEN_SPACE.sub("", expected_language_corpus)
expected_char_language_corpus = ENTITY_TOKEN_SPACE.sub(
"", expected_char_language_corpus
)
expected_word_language_corpus = ENTITY_TOKEN_SPACE.sub(
"", expected_word_language_corpus
)
expected_subword_language_corpus = ENTITY_TOKEN_SPACE.sub(
"", expected_subword_language_corpus
)
# Replace double spaces with regular space
if not keep_spaces:
expected_language_corpus = TWO_SPACES_LM_REGEX.sub(
"", expected_language_corpus
expected_char_language_corpus = TWO_SPACES_LM_REGEX.sub(
"", expected_char_language_corpus
)
expected_word_language_corpus = TWO_SPACES_LM_REGEX.sub(
"", expected_word_language_corpus
)
expected_subword_language_corpus = TWO_SPACES_LM_REGEX.sub(
"", expected_subword_language_corpus
)
assert (
output / "language_model" / "corpus_characters.txt"
).read_text() == expected_char_language_corpus
assert (
output / "language_model" / "corpus_words.txt"
).read_text() == expected_word_language_corpus
assert (
output / "language_model" / "corpus.txt"
).read_text() == expected_language_corpus
output / "language_model" / "corpus_subwords.txt"
).read_text() == expected_subword_language_corpus
# Check "language_tokens.txt"
expected_language_tokens = [
t if t != " " else "" for t in sorted(list(expected_charset))
"" if t.isspace() else t for t in sorted(list(expected_charset))
]
expected_language_tokens.append("")
assert (output / "language_model" / "tokens.txt").read_text() == "\n".join(
......@@ -503,11 +659,29 @@ def test_extract(
)
# Check "language_lexicon.txt"
expected_language_lexicon = [f"{t} {t}" for t in expected_language_tokens]
assert (output / "language_model" / "lexicon.txt").read_text() == "\n".join(
expected_language_lexicon
expected_language_char_lexicon = [f"{t} {t}" for t in expected_language_tokens]
assert (
output / "language_model" / "lexicon_characters.txt"
).read_text() == "\n".join(expected_language_char_lexicon)
word_vocab = set([word for word in expected_word_language_corpus.split()])
expected_language_word_lexicon = [
f"{word} {' '.join(word)}" for word in sorted(word_vocab)
]
assert (output / "language_model" / "lexicon_words.txt").read_text() == "\n".join(
expected_language_word_lexicon
)
subword_vocab = set(
[subword for subword in expected_subword_language_corpus.split()]
)
expected_language_subword_lexicon = [
f"{subword} {' '.join(subword)}" for subword in sorted(subword_vocab)
]
assert (
output / "language_model" / "lexicon_subwords.txt"
).read_text() == "\n".join(expected_language_subword_lexicon)
# Check cropped images
for expected_path in expected_paths:
if expected_path.suffix != ".jpg":
......@@ -521,7 +695,7 @@ def test_extract(
)
@patch("dan.datasets.extract.extract.ArkindexExtractor.build_iiif_url")
@patch("dan.datasets.extract.arkindex.ArkindexExtractor.build_iiif_url")
def test_download_image_error(iiif_url, caplog, capsys):
task = {
"split": "train",
......
......@@ -8,8 +8,8 @@ import pytest
import yaml
from dan.ocr.predict.attention import Level
from dan.ocr.predict.prediction import DAN
from dan.ocr.predict.prediction import run as run_prediction
from dan.ocr.predict.inference import DAN
from dan.ocr.predict.inference import run as run_prediction
from dan.utils import parse_tokens, read_yaml
from tests import FIXTURES
......@@ -67,7 +67,11 @@ def test_predict(image_name, expected_prediction):
"0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84",
None,
1.0,
{"text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241"},
{
"text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",
"language_model": {},
"confidences": {},
},
),
(
"0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84",
......@@ -75,6 +79,7 @@ def test_predict(image_name, expected_prediction):
1.0,
{
"text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",
"language_model": {},
"confidences": {
"total": 1.0,
"word": [
......@@ -96,6 +101,7 @@ def test_predict(image_name, expected_prediction):
3.5,
{
"text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",
"language_model": {},
"confidences": {
"total": 0.93,
"ner": [
......@@ -127,6 +133,7 @@ def test_predict(image_name, expected_prediction):
1.0,
{
"text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",
"language_model": {},
"confidences": {
"total": 1.0,
"line": [
......@@ -144,6 +151,7 @@ def test_predict(image_name, expected_prediction):
3.5,
{
"text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",
"language_model": {},
"confidences": {
"total": 0.93,
"ner": [
......@@ -169,7 +177,11 @@ def test_predict(image_name, expected_prediction):
"0dfe8bcd-ed0b-453e-bf19-cc697012296e",
None,
1.0,
{"text": "ⓈTemplié ⒻMarcelle Ⓑ93 ⓁS Ⓚch ⓄE dactylo Ⓟ18376"},
{
"text": "ⓈTemplié ⒻMarcelle Ⓑ93 ⓁS Ⓚch ⓄE dactylo Ⓟ18376",
"language_model": {},
"confidences": {},
},
),
(
"0dfe8bcd-ed0b-453e-bf19-cc697012296e",
......@@ -177,6 +189,7 @@ def test_predict(image_name, expected_prediction):
1.0,
{
"text": "ⓈTemplié ⒻMarcelle Ⓑ93 ⓁS Ⓚch ⓄE dactylo Ⓟ18376",
"language_model": {},
"confidences": {
"total": 1.0,
"ner": [
......@@ -260,13 +273,21 @@ def test_predict(image_name, expected_prediction):
"2c242f5c-e979-43c4-b6f2-a6d4815b651d",
False,
1.0,
{"text": "Ⓢd ⒻCharles Ⓑ11 ⓁP ⒸC ⓀF Ⓞd Ⓟ14 31"},
{
"text": "Ⓢd ⒻCharles Ⓑ11 ⓁP ⒸC ⓀF Ⓞd Ⓟ14 31",
"language_model": {},
"confidences": {},
},
),
(
"ffdec445-7f14-4f5f-be44-68d0844d0df1",
False,
1.0,
{"text": "ⓈNaudin ⒻMarie Ⓑ53 ⓁS Ⓒv ⓀBelle mère"},
{
"text": "ⓈNaudin ⒻMarie Ⓑ53 ⓁS Ⓒv ⓀBelle mère",
"language_model": {},
"confidences": {},
},
),
),
)
......@@ -315,7 +336,13 @@ def test_run_prediction(
["0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84"],
None,
1.0,
[{"text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241"}],
[
{
"text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",
"language_model": {},
"confidences": {},
}
],
),
(
["0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84"],
......@@ -324,6 +351,7 @@ def test_run_prediction(
[
{
"text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",
"language_model": {},
"confidences": {
"total": 1.0,
"word": [
......@@ -350,6 +378,7 @@ def test_run_prediction(
[
{
"text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",
"language_model": {},
"confidences": {
"total": 1.0,
"ner": [
......@@ -376,6 +405,7 @@ def test_run_prediction(
},
{
"text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",
"language_model": {},
"confidences": {
"total": 1.0,
"ner": [
......@@ -409,6 +439,7 @@ def test_run_prediction(
[
{
"text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",
"language_model": {},
"confidences": {
"total": 1.0,
"word": [
......@@ -433,8 +464,16 @@ def test_run_prediction(
False,
1.0,
[
{"text": "Ⓢd ⒻCharles Ⓑ11 ⓁP ⒸC ⓀF Ⓞd Ⓟ14 31"},
{"text": "ⓈNaudin ⒻMarie Ⓑ53 ⓁS Ⓒv ⓀBelle mère"},
{
"text": "Ⓢd ⒻCharles Ⓑ11 ⓁP ⒸC ⓀF Ⓞd Ⓟ14 31",
"language_model": {},
"confidences": {},
},
{
"text": "ⓈNaudin ⒻMarie Ⓑ53 ⓁS Ⓒv ⓀBelle mère",
"language_model": {},
"confidences": {},
},
],
),
),
......
# -*- coding: utf-8 -*-
import pytest
from dan.datasets.tokens.generate import LIMIT, OFFSET, get_token, run
from tests import FIXTURES
TOKENS_DATA_PATH = FIXTURES / "tokens"
def test_get_token():
token_generator = get_token()
tokens = []
for _ in range(LIMIT - OFFSET):
tokens.append(next(token_generator))
assert tokens == [
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
]
@pytest.mark.parametrize(
"end_tokens, expected_file",
[
(True, TOKENS_DATA_PATH / "end_tokens.yml"),
(False, TOKENS_DATA_PATH / "no_end_tokens.yml"),
],
)
def test_tokens(end_tokens, expected_file, tmp_path):
output_file = tmp_path / "tokens.yml"
run(
entities=FIXTURES / "entities.yml",
end_tokens=end_tokens,
output_file=output_file,
)
assert output_file.read_text() == expected_file.read_text()