Commit 2e719c9e authored by Yoann Schneider
Merge branch 'eval-load-lm' into 'main'

Load a language model and decode with it during evaluation

Closes #252

See merge request !347
parents 389c3505 48a4fceb
@@ -32,6 +32,7 @@ class Inference(NamedTuple):
image: str
ground_truth: str
prediction: str
lm_prediction: str
wer: float
@@ -21,7 +21,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from dan.ocr.decoder import GlobalHTADecoder
from dan.ocr.decoder import CTCLanguageDecoder, GlobalHTADecoder
from dan.ocr.encoder import FCN_Encoder
from dan.ocr.manager.metrics import Inference, MetricManager
from dan.ocr.manager.ocr import OCRDatasetManager
@@ -33,6 +33,7 @@ if MLFLOW_AVAILABLE:
import mlflow
logger = logging.getLogger(__name__)
MODEL_NAME_ENCODER = "encoder"
MODEL_NAME_DECODER = "decoder"
MODEL_NAMES = (MODEL_NAME_ENCODER, MODEL_NAME_DECODER)
@@ -195,6 +196,28 @@ class GenericTrainingManager:
output_device=self.ddp_config["rank"],
)
# Instantiate LM decoder
self.lm_decoder = None
if self.params["model"].get("lm") and self.params["model"]["lm"]["weight"] > 0:
logger.info(
f"Decoding with a language model (weight={self.params['model']['lm']['weight']})."
)
# Check files
model_path = self.params["model"]["lm"]["path"]
assert model_path.is_file(), f"File {model_path} not found"
base_path = model_path.parent
lexicon_path = base_path / "lexicon.txt"
assert lexicon_path.is_file(), f"File {lexicon_path} not found"
tokens_path = base_path / "tokens.txt"
assert tokens_path.is_file(), f"File {tokens_path} not found"
# Load LM decoder
self.lm_decoder = CTCLanguageDecoder(
language_model_path=str(model_path),
lexicon_path=str(lexicon_path),
tokens_path=str(tokens_path),
language_model_weight=self.params["model"]["lm"]["weight"],
)
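# Illustrative sketch (not part of this commit): standalone use of the
# decoder configured above. `logits` would be a float tensor of shape
# (batch, vocab, time), such as the `tot_pred` built during evaluation
# below; the returned dict exposes the decoded strings under "text".
#
#     decoder = CTCLanguageDecoder(
#         language_model_path="folder/model.arpa",  # assumed ARPA LM file
#         lexicon_path="folder/lexicon.txt",
#         tokens_path="folder/tokens.txt",
#         language_model_weight=1.0,
#     )
#     texts = decoder(logits, logits_len)["text"]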
# Handle curriculum dropout
self.dropout_scheduler = DropoutScheduler(self.models)
@@ -816,6 +839,7 @@ class GenericTrainingManager:
batch_data["names"],
batch_values["str_y"],
batch_values["str_x"],
batch_values.get("str_lm", repeat("")),
repeat(display_values["wer"]),
)
)
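# `str_lm` is only present when an LM decoder is configured; the
# `repeat("")` fallback above keeps the Inference tuples well-formed
# without one.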
@@ -1059,6 +1083,13 @@ class Manager(GenericTrainingManager):
)
predicted_tokens_len = torch.ones((b,), dtype=torch.int, device=self.device)
# The end token index will be used for CTC decoding
tot_pred = torch.zeros(
(b, len(self.dataset.charset) + 1, max_chars),
dtype=torch.float,
device=self.device,
)
whole_output = list()
confidence_scores = list()
cache = None
@@ -1112,6 +1143,10 @@ class Manager(GenericTrainingManager):
cache=cache,
num_pred=1,
)
# Store this step's logits in the full prediction tensor
tot_pred[:, :, i : i + 1] = pred
whole_output.append(output)
confidence_scores.append(
torch.max(torch.softmax(pred[:, :], dim=1), dim=1).values
@@ -1158,4 +1193,7 @@ class Manager(GenericTrainingManager):
"confidence_score": confidence_scores,
"time": process_time,
}
if self.lm_decoder:
values["str_lm"] = self.lm_decoder(tot_pred, prediction_len)["text"]
return values
@@ -37,6 +37,10 @@ def update_config(config: dict):
# .model.decoder.class = GlobalHTADecoder
config["model"]["decoder"]["class"] = GlobalHTADecoder
# .model.lm.path to Path
if config["model"].get("lm", {}).get("path"):
config["model"]["lm"]["path"] = Path(config["model"]["lm"]["path"])
# Update preprocessing type
for prepro in config["training"]["data"]["preprocessings"]:
prepro["type"] = Preprocessing(prepro["type"])
@@ -41,4 +41,4 @@ To train a DAN model, please refer to the [documentation of the training command
## 3. Predict
Once the training is complete, you can apply a trained DAN model on an image using the [predict command](../usage/predict/index.md) and the `inference_parameters.yml` file, located in `{training.output_folder}/results`.
@@ -166,14 +166,13 @@ It will create the following JSON file named after the image and a GIF showing a
This example assumes that you have already [trained a language model](../train/language_model.md).
Note that:
- the `weight` parameter defines how much weight to give to the language model. It should be set carefully (usually between 0.5 and 2.0) as it will affect the quality of the predictions.
- linebreaks are treated as spaces by language models; as a result, predictions will not include linebreaks.
!!! note
    - the `weight` parameter defines how much weight to give to the language model. It should be set carefully (usually between 0.5 and 2.0) as it will affect the quality of the predictions.
    - linebreaks are treated as spaces by language models; as a result, predictions will not include linebreaks.
#### Language model at character level
First, update the `inference_parameters.yml` file obtained during DAN training.
Update the `parameters.yml` file obtained during DAN training.
```yaml
parameters:
@@ -185,8 +184,6 @@ parameters:
weight: 0.5
```
Note that the `weight` parameter defines how much weight to give to the language model. It should be set carefully (usually between 0.5 and 2.0) as it will affect the quality of the predictions.
Then, run this command:
```shell
@@ -211,7 +208,7 @@ It will create the following JSON file named after the image in the `predict_cha
#### Language model at subword level
Update the `inference_parameters.yml` file obtained during DAN training.
Update the `parameters.yml` file obtained during DAN training.
```yaml
parameters:
@@ -247,7 +244,7 @@ It will create the following JSON file named after the image in the `predict_sub
#### Language model at word level
Update the `inference_parameters.yml` file obtained during DAN training.
Update the `parameters.yml` file obtained during DAN training.
```yaml
parameters:
@@ -17,59 +17,131 @@ To determine the value to use for `dataset.max_char_prediction`, you can use the
## Model parameters
| Name | Description | Type | Default |
| ----------------------------------- | ---------------------------------------------------------------------------------- | ------- | ------- |
| `model.transfered_charset` | Transfer learning of the decision layer based on the charset of the model to transfer. | `bool` | `True` |
| `model.additional_tokens` | For decision layer = \[<eot>, <pad>\], only for transferred charset. | `int` | `1` |
| `model.encoder.dropout` | Dropout probability in the encoder. | `float` | `0.5` |
| `model.encoder.nb_layers` | Number of layers in the encoder. | `int` | `5` |
| `model.h_max` | Maximum height for encoder output (for 2D positional embedding). | `int` | `500` |
| `model.w_max` | Maximum width for encoder output (for 2D positional embedding). | `int` | `1000` |
| `model.decoder.enc_dim` | Dimension of features extracted by the encoder. | `int` | `256` |
| `model.decoder.l_max` | Maximum predicted sequence length (for 1D positional embedding). | `int` | `15000` |
| `model.decoder.dec_num_layers` | Number of transformer decoder layers. | `int` | `8` |
| `model.decoder.dec_num_heads` | Number of heads in transformer decoder layers. | `int` | `4` |
| `model.decoder.dec_res_dropout` | Dropout probability in transformer decoder layers. | `float` | `0.1` |
| `model.decoder.dec_pred_dropout` | Dropout rate before decision layer. | `float` | `0.1` |
| `model.decoder.dec_att_dropout` | Dropout rate in multi head attention. | `float` | `0.1` |
| `model.decoder.dec_dim_feedforward` | Number of dimensions for feedforward layer in transformer decoder layers. | `int` | `256` |
| `model.decoder.attention_win` | Length of attention window. | `int` | `100` |
| Name | Description | Type | Default |
| -------------------------- | ---------------------------------------------------------------------------------- | ------ | ------- |
| `model.transfered_charset` | Transfer learning of the decision layer based on the charset of the model to transfer. | `bool` | `True` |
| `model.additional_tokens` | For decision layer = \[<eot>, <pad>\], only for transferred charset. | `int` | `1` |
| `model.h_max` | Maximum height for encoder output (for 2D positional embedding). | `int` | `500` |
| `model.w_max` | Maximum width for encoder output (for 2D positional embedding). | `int` | `1000` |
### Encoder
| Name | Description | Type | Default |
| ------------------------- | ----------------------------------- | ------- | ------- |
| `model.encoder.dropout` | Dropout probability in the encoder. | `float` | `0.5` |
| `model.encoder.nb_layers` | Number of layers in the encoder. | `int` | `5` |
### Decoder
| Name | Description | Type | Default |
| ----------------------------------- | ------------------------------------------------------------------------- | ------- | ------- |
| `model.decoder.enc_dim` | Dimension of features extracted by the encoder. | `int` | `256` |
| `model.decoder.l_max` | Maximum predicted sequence length (for 1D positional embedding). | `int` | `15000` |
| `model.decoder.dec_num_layers` | Number of transformer decoder layers. | `int` | `8` |
| `model.decoder.dec_num_heads` | Number of heads in transformer decoder layers. | `int` | `4` |
| `model.decoder.dec_res_dropout` | Dropout probability in transformer decoder layers. | `float` | `0.1` |
| `model.decoder.dec_pred_dropout` | Dropout rate before decision layer. | `float` | `0.1` |
| `model.decoder.dec_att_dropout` | Dropout rate in multi head attention. | `float` | `0.1` |
| `model.decoder.dec_dim_feedforward` | Number of dimensions for feedforward layer in transformer decoder layers. | `int` | `256` |
| `model.decoder.attention_win` | Length of attention window. | `int` | `100` |
### Language model
This assumes that you have already [trained a language model](../train/language_model.md).
| Name | Description | Type | Default |
| ----------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------- | ------- | ------- |
| `model.lm.path` | Path to the language model. | `str` | |
| `model.lm.weight` | How much weight to give to the language model. It should be set carefully (usually between 0.5 and 2.0) as it will affect the quality of the predictions. | `float` | |
!!! note

    Linebreaks are treated as spaces by language models; as a result, predictions will not include linebreaks.
The `model.lm.path` argument expects a path to the language model file, and its parent folder must also contain:
- a `lexicon.txt` file,
- a `tokens.txt` file.
You should get the following tree structure:
```
folder/
├── <model.lm.path> # Path to the language model
├── lexicon.txt
└── tokens.txt
```
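
A minimal sketch of the matching configuration entry, assuming the training configuration is a Python dictionary (as handled by `update_config`) and using illustrative file names:

```python
config = {"model": {}}
config["model"]["lm"] = {
    # Illustrative path; lexicon.txt and tokens.txt must sit next to the
    # model, as shown in the tree above. `update_config` converts this
    # value to a `pathlib.Path`.
    "path": "folder/model.arpa",
    # A weight > 0 enables LM decoding during evaluation.
    "weight": 1.0,
}
```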
## Training parameters
| Name | Description | Type | Default |
| ------------------------------------------------ | --------------------------------------------------------------------------------------------------------------------------- | ------------ | --------------------------------------------------------------------------- |
| `training.data.batch_size` | Mini-batch size for the training loop. | `int` | `2` |
| `training.data.load_in_memory` | Load all images in CPU memory. | `bool` | `True` |
| `training.data.worker_per_gpu` | Number of parallel processes per GPU for data loading. | `int` | `4` |
| `training.data.preprocessings` | List of pre-processing functions to apply to input images. | `list` | (see [dedicated section](#data-preprocessing)) |
| `training.data.augmentation` | Whether to use data augmentation on the training set. | `bool` | `True` (see [dedicated section](#data-augmentation)) |
| `training.output_folder` | Directory for checkpoint and results. | `str` | |
| `training.max_nb_epochs` | Maximum number of epochs before stopping training. | `int` | `800` |
| `training.load_epoch` | Model to load. Should be either `"best"` (evaluation) or `"last"` (training). | `str` | `"last"` |
| `training.device.use_ddp` | Whether to use DistributedDataParallel. | `bool` | `False` |
| `training.device.ddp_port` | DDP port. | `int` | `20027` |
| `training.device.use_amp` | Whether to enable automatic mixed precision. | `bool` | `True` |
| `training.device.nb_gpu` | Number of GPUs to train DAN. Set to `null` to use all GPUs available. | `int` | |
| `training.device.force` | Use a specific device if available. Use `cpu` to train on CPU (for debugging) or `cuda`/`cuda:$gpu_device` to train on GPU. | `str` | |
| `training.optimizers.all.args.lr` | Learning rate for the optimizer. | `float` | `0.0001` |
| `training.optimizers.all.args.amsgrad` | Whether to use AMSGrad optimization. | `bool` | `False` |
| `training.lr_schedulers` | Learning rate schedulers. | custom class | |
| `training.validation.eval_on_valid` | Whether to evaluate and log metrics on the validation set during training. | `bool` | `True` |
| `training.validation.eval_on_valid_interval` | Interval (in epochs) to evaluate during training. | `int` | `5` |
| `training.validation.set_name_focus_metric` | Dataset to focus on to select best weights. | `str` | |
| `training.metrics.train` | List of metrics to compute during training. | `list` | `["loss_ce", "cer", "cer_no_token", "wer", "wer_no_punct", "wer_no_token"]` |
| `training.metrics.eval` | List of metrics to compute during validation. | `list` | `["cer", "cer_no_token", "wer", "wer_no_punct", "wer_no_token"]` |
| `training.label_noise_scheduler.min_error_rate` | Minimum ratio of teacher forcing. | `float` | `0.2` |
| `training.label_noise_scheduler.max_error_rate` | Maximum ratio of teacher forcing. | `float` | `0.2` |
| `training.label_noise_scheduler.total_num_steps` | Number of steps before stopping teacher forcing. | `float` | `5e4` |
| `training.transfer_learning.encoder` | Model to load for the encoder \[state_dict_name, checkpoint_path, learnable, strict\]. | `list` | `["encoder", "pretrained_models/dan_rimes_page.pt", True, True]` |
| `training.transfer_learning.decoder` | Model to load for the decoder \[state_dict_name, checkpoint_path, learnable, strict\]. | `list` | `["encoder", "pretrained_models/dan_rimes_page.pt", True, False]` |
- To train on several GPUs, simply set the `training.use_ddp` parameter to `True`. By default, the model will use all available GPUs. To restrict access to fewer GPUs, one can modify the `training.nb_gpu` parameter.
- During the validation stage, the batch size is set to 1. This avoids problems associated with image sizes that can be very different inside batches and lead to significant padding, resulting in performance degradations.
### Data preprocessing
| Name | Description | Type | Default |
| ------------------------ | --------------------------------------------------------------------------- | ------------ | -------- |
| `training.output_folder` | Directory for checkpoint and results. | `str` | |
| `training.max_nb_epochs` | Maximum number of epochs before stopping training. | `int` | `800` |
| `training.load_epoch` | Model to load. Should be either `"best"` (evaluation) or `"last"` (training). | `str` | `"last"` |
| `training.lr_schedulers` | Learning rate schedulers. | custom class | |
### Device
| Name | Description | Type | Default |
| -------------------------- | --------------------------------------------------------------------------------------------------------------------------- | ------ | ------- |
| `training.device.use_ddp` | Whether to use DistributedDataParallel. | `bool` | `False` |
| `training.device.ddp_port` | DDP port. | `int` | `20027` |
| `training.device.use_amp` | Whether to enable automatic mixed precision. | `bool` | `True` |
| `training.device.nb_gpu` | Number of GPUs to train DAN. Set to `null` to use all GPUs available. | `int` | |
| `training.device.force` | Use a specific device if available. Use `cpu` to train on CPU (for debugging) or `cuda`/`cuda:$gpu_device` to train on GPU. | `str` | |
To train on several GPUs, simply set the `training.device.use_ddp` parameter to `True`. By default, the model will use all available GPUs. To restrict access to fewer GPUs, one can modify the `training.device.nb_gpu` parameter.
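
As a sketch, the corresponding device section might look as follows, assuming the training configuration is a Python dictionary (values are illustrative):

```python
config = {"training": {}}
config["training"]["device"] = {
    "use_ddp": True,  # enable DistributedDataParallel
    "ddp_port": 20027,
    "use_amp": True,  # automatic mixed precision
    "nb_gpu": 2,  # set to None (null) to use every available GPU
    "force": None,  # or "cpu" / "cuda:0" to pin a device
}
```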
### Optimizers
| Name | Description | Type | Default |
| -------------------------------------- | ------------------------------------ | ------- | -------- |
| `training.optimizers.all.args.lr` | Learning rate for the optimizer. | `float` | `0.0001` |
| `training.optimizers.all.args.amsgrad` | Whether to use AMSGrad optimization. | `bool` | `False` |
### Validation
| Name | Description | Type | Default |
| -------------------------------------------- | -------------------------------------------------------------------------- | ------ | ------- |
| `training.validation.eval_on_valid` | Whether to evaluate and log metrics on the validation set during training. | `bool` | `True` |
| `training.validation.eval_on_valid_interval` | Interval (in epochs) to evaluate during training. | `int` | `5` |
| `training.validation.set_name_focus_metric` | Dataset to focus on to select best weights. | `str` | |
During the validation stage, the batch size is set to 1. This avoids problems associated with image sizes that can be very different inside batches and lead to significant padding, resulting in performance degradations.
### Metrics
| Name | Description | Type | Default |
| ------------------------ | --------------------------------------------- | ------ | --------------------------------------------------------------------------- |
| `training.metrics.train` | List of metrics to compute during training. | `list` | `["loss_ce", "cer", "cer_no_token", "wer", "wer_no_punct", "wer_no_token"]` |
| `training.metrics.eval` | List of metrics to compute during validation. | `list` | `["cer", "cer_no_token", "wer", "wer_no_punct", "wer_no_token"]` |
### Label noise scheduler
| Name | Description | Type | Default |
| ------------------------------------------------ | ------------------------------------------------ | ------- | ------- |
| `training.label_noise_scheduler.min_error_rate` | Minimum ratio of teacher forcing. | `float` | `0.2` |
| `training.label_noise_scheduler.max_error_rate` | Maximum ratio of teacher forcing. | `float` | `0.2` |
| `training.label_noise_scheduler.total_num_steps` | Number of steps before stopping teacher forcing. | `float` | `5e4` |
### Transfer learning
| Name | Description | Type | Default |
| ------------------------------------ | -------------------------------------------------------------------------------------- | ------ | ----------------------------------------------------------------- |
| `training.transfer_learning.encoder` | Model to load for the encoder \[state_dict_name, checkpoint_path, learnable, strict\]. | `list` | `["encoder", "pretrained_models/dan_rimes_page.pt", True, True]` |
| `training.transfer_learning.decoder` | Model to load for the decoder \[state_dict_name, checkpoint_path, learnable, strict\]. | `list` | `["decoder", "pretrained_models/dan_rimes_page.pt", True, False]` |
### Data
| Name | Description | Type | Default |
| ------------------------------ | ---------------------------------------------------------- | ------ | ----------------------------------------------- |
| `training.data.batch_size` | Mini-batch size for the training loop. | `int` | `2` |
| `training.data.load_in_memory` | Load all images in CPU memory. | `bool` | `True` |
| `training.data.worker_per_gpu` | Number of parallel processes per GPU for data loading. | `int` | `4` |
| `training.data.preprocessings` | List of pre-processing functions to apply to input images. | `list` | (see [dedicated section](#preprocessing)) |
| `training.data.augmentation` | Whether to use data augmentation on the training set. | `bool` | `True` (see [dedicated section](#augmentation)) |
#### Preprocessing
Preprocessing is applied before training the network (see the [dedicated reference](../../ref/ocr/managers/dataset.md)). The list of accepted transforms is defined in the [transforms reference](../../ref/ocr/transforms.md#dan.ocr.transforms.Preprocessing).
@@ -124,7 +196,7 @@ Usage:
]
```
### Data augmentation
#### Augmentation
Augmentation transformations are applied on-the-fly during training to artificially increase data variability.
@@ -24,6 +24,6 @@ parameters:
max_width: 1500
language_model:
model: tests/data/prediction/language_model.arpa
lexicon: tests/data/prediction/language_lexicon.txt
tokens: tests/data/prediction/language_tokens.txt
lexicon: tests/data/prediction/lexicon.txt
tokens: tests/data/prediction/tokens.txt
weight: 1.0
@@ -8,9 +8,12 @@ import yaml
from prettytable import PrettyTable
from dan.ocr import evaluate
from dan.ocr.manager.metrics import Inference
from dan.ocr.utils import add_metrics_table_row, create_metrics_table
from tests import FIXTURES
PREDICTION_DATA_PATH = FIXTURES / "prediction"
def test_create_metrics_table():
metric_names = ["ignored", "wer", "cer", "time", "ner"]
@@ -115,14 +118,228 @@ def test_evaluate(capsys, training_res, val_res, test_res, evaluate_config):
/ f"predict_training-{split_name}_1685.yaml"
)
with filename.open() as f:
    # Remove the times from the results as they vary
    res = {
        metric: value
        for metric, value in yaml.safe_load(f).items()
        if "time" not in metric
    }
    assert res == expected_res
# Remove results files
shutil.rmtree(evaluate_config["training"]["output_folder"] / "results")
# Check the metrics Markdown table
captured_std = capsys.readouterr()
last_printed_lines = captured_std.out.split("\n")[10:]
assert (
"\n".join(last_printed_lines)
== Path(FIXTURES / "evaluate" / "metrics_table.md").read_text()
)
@pytest.mark.parametrize(
"language_model_weight, expected_inferences",
(
(
0.0,
[
(
"0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84.png", # Image
"ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier Ⓟ12241", # Ground truth
"ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241", # Prediction
"", # LM prediction
0.125, # WER
),
(
"0dfe8bcd-ed0b-453e-bf19-cc697012296e.png", # Image
"ⓈTemplié ⒻMarcelle Ⓑ93 ⓁJ Ⓚch ⓄE dachyle", # Ground truth
"ⓈTemplié ⒻMarcelle Ⓑ93 ⓁS Ⓚch ⓄE dactylo Ⓟ18376", # Prediction
"", # LM prediction
0.2667, # WER
),
(
"2c242f5c-e979-43c4-b6f2-a6d4815b651d.png", # Image
"ⓈA ⒻCharles Ⓑ11 ⓁP ⒸC ⓀF ⓄA Ⓟ14331", # Ground truth
"Ⓢd ⒻCharles Ⓑ11 ⓁP ⒸC ⓀF Ⓞd Ⓟ14 31", # Prediction
"", # LM prediction
0.5, # WER
),
(
"ffdec445-7f14-4f5f-be44-68d0844d0df1.png", # Image
"ⓈNaudin ⒻMarie Ⓑ53 ⓁS ⒸV ⓀBelle mère", # Ground truth
"ⓈNaudin ⒻMarie Ⓑ53 ⓁS Ⓒv ⓀBelle mère", # Prediction
"", # LM prediction
0.1429, # WER
),
],
),
(
1.0,
[
(
"0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84.png", # Image
"ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier Ⓟ12241", # Ground truth
"ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241", # Prediction
"ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241", # LM prediction
0.125, # WER
),
(
"0dfe8bcd-ed0b-453e-bf19-cc697012296e.png", # Image
"ⓈTemplié ⒻMarcelle Ⓑ93 ⓁJ Ⓚch ⓄE dachyle", # Ground truth
"ⓈTemplié ⒻMarcelle Ⓑ93 ⓁS Ⓚch ⓄE dactylo Ⓟ18376", # Prediction
"ⓈTemplié ⒻMarcelle Ⓑ93 ⓁS Ⓚch ⓄE dactylo Ⓟ18376", # LM prediction
0.2667, # WER
),
(
"2c242f5c-e979-43c4-b6f2-a6d4815b651d.png", # Image
"ⓈA ⒻCharles Ⓑ11 ⓁP ⒸC ⓀF ⓄA Ⓟ14331", # Ground truth
"Ⓢd ⒻCharles Ⓑ11 ⓁP ⒸC ⓀF Ⓞd Ⓟ14 31", # Prediction
"Ⓢd ⒻCharles Ⓑ11 ⓁP ⒸC ⓀF Ⓞd Ⓟ14 31", # LM prediction
0.5, # WER
),
(
"ffdec445-7f14-4f5f-be44-68d0844d0df1.png", # Image
"ⓈNaudin ⒻMarie Ⓑ53 ⓁS ⒸV ⓀBelle mère", # Ground truth
"ⓈNaudin ⒻMarie Ⓑ53 ⓁS Ⓒv ⓀBelle mère", # Prediction
"ⓈNaudin ⒻMarie Ⓑ53 ⓁS Ⓒv ⓀBelle mère", # LM prediction
0.1429, # WER
),
],
),
(
2.0,
[
(
"0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84.png", # Image
"ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier Ⓟ12241", # Ground truth
"ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241", # Prediction
"ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241", # LM prediction
0.125, # WER
),
(
"0dfe8bcd-ed0b-453e-bf19-cc697012296e.png", # Image
"ⓈTemplié ⒻMarcelle Ⓑ93 ⓁJ Ⓚch ⓄE dachyle", # Ground truth
"ⓈTemplié ⒻMarcelle Ⓑ93 ⓁS Ⓚch ⓄE dactylo Ⓟ18376", # Prediction
"ⓈTemplié ⒻMarcelle Ⓑ93 ⓁS Ⓚch ⓄE dactylo Ⓟ18376", # LM prediction
0.2667, # WER
),
(
"2c242f5c-e979-43c4-b6f2-a6d4815b651d.png", # Image
"ⓈA ⒻCharles Ⓑ11 ⓁP ⒸC ⓀF ⓄA Ⓟ14331", # Ground truth
"Ⓢd ⒻCharles Ⓑ11 ⓁP ⒸC ⓀF Ⓞd Ⓟ14 31", # Prediction
"Ⓢd ⒻCharles Ⓑ11 ⓁP ⒸC ⓀF Ⓞd Ⓟ14331", # LM prediction
0.5, # WER
),
(
"ffdec445-7f14-4f5f-be44-68d0844d0df1.png", # Image
"ⓈNaudin ⒻMarie Ⓑ53 ⓁS ⒸV ⓀBelle mère", # Ground truth
"ⓈNaudin ⒻMarie Ⓑ53 ⓁS Ⓒv ⓀBelle mère", # Prediction
"ⓈNaudin ⒻMarie Ⓑ53 ⓁS Ⓒv ⓀBelle mère", # LM prediction
0.1429, # WER
),
],
),
),
)
def test_evaluate_language_model(
capsys, evaluate_config, language_model_weight, expected_inferences, monkeypatch
):
# LM predictions are never used or displayed,
# so we mock the `Inference` class to temporarily check the results
global nb_inferences
nb_inferences = 0
class MockInference(Inference):
def __new__(cls, *args, **kwargs):
global nb_inferences
assert args == expected_inferences[nb_inferences]
nb_inferences += 1
return super().__new__(cls, *args, **kwargs)
monkeypatch.setattr("dan.ocr.manager.training.Inference", MockInference)
# Use the fixtures folder as the output folder
evaluate_config["training"]["output_folder"] = FIXTURES / "evaluate"
# Use an LM decoder
evaluate_config["model"]["lm"] = {
"path": PREDICTION_DATA_PATH / "language_model.arpa",
"weight": language_model_weight,
}
evaluate.run(evaluate_config, evaluate.NERVAL_THRESHOLD)
# Check that the evaluation results are correct
for split_name, expected_res in [
(
"train",
{
"nb_chars": 90,
"cer": 0.1889,
"nb_chars_no_token": 76,
"cer_no_token": 0.2105,
"nb_words": 15,
"wer": 0.2667,
"nb_words_no_punct": 15,
"wer_no_punct": 0.2667,
"nb_words_no_token": 15,
"wer_no_token": 0.2667,
"nb_tokens": 14,
"ner": 0.0714,
"nb_samples": 2,
},
),
(
"val",
{
"nb_chars": 34,
"cer": 0.0882,
"nb_chars_no_token": 26,
"cer_no_token": 0.1154,
"nb_words": 8,
"wer": 0.5,
"nb_words_no_punct": 8,
"wer_no_punct": 0.5,
"nb_words_no_token": 8,
"wer_no_token": 0.5,
"nb_tokens": 8,
"ner": 0.0,
"nb_samples": 1,
},
),
(
"test",
{
"nb_chars": 36,
"cer": 0.0278,
"nb_chars_no_token": 30,
"cer_no_token": 0.0333,
"nb_words": 7,
"wer": 0.1429,
"nb_words_no_punct": 7,
"wer_no_punct": 0.1429,
"nb_words_no_token": 7,
"wer_no_token": 0.1429,
"nb_tokens": 6,
"ner": 0.0,
"nb_samples": 1,
},
),
]:
filename = (
evaluate_config["training"]["output_folder"]
/ "results"
/ f"predict_training-{split_name}_1685.yaml"
)
with filename.open() as f:
    # Remove the times from the results as they vary
    res = {
        metric: value
        for metric, value in yaml.safe_load(f).items()
        if "time" not in metric
    }
    assert res == expected_res
# Remove results files
shutil.rmtree(evaluate_config["training"]["output_folder"] / "results")