From 48a4fcebd886d0af31cd07f0c24c4a4463f8be66 Mon Sep 17 00:00:00 2001 From: Manon Blanco <blanco@teklia.com> Date: Fri, 19 Jan 2024 14:42:00 +0000 Subject: [PATCH] Load a language model and decode with it during evaluation --- dan/ocr/manager/metrics.py | 1 + dan/ocr/manager/training.py | 40 +++- dan/ocr/utils.py | 4 + docs/get_started/training.md | 2 +- docs/usage/predict/index.md | 15 +- docs/usage/train/config.md | 174 ++++++++++---- .../{language_lexicon.txt => lexicon.txt} | 0 tests/data/prediction/parameters.yml | 4 +- .../{language_tokens.txt => tokens.txt} | 0 tests/test_evaluate.py | 225 +++++++++++++++++- 10 files changed, 397 insertions(+), 68 deletions(-) rename tests/data/prediction/{language_lexicon.txt => lexicon.txt} (100%) rename tests/data/prediction/{language_tokens.txt => tokens.txt} (100%) diff --git a/dan/ocr/manager/metrics.py b/dan/ocr/manager/metrics.py index 07cece06..393044ed 100644 --- a/dan/ocr/manager/metrics.py +++ b/dan/ocr/manager/metrics.py @@ -32,6 +32,7 @@ class Inference(NamedTuple): image: str ground_truth: str prediction: str + lm_prediction: str wer: float diff --git a/dan/ocr/manager/training.py b/dan/ocr/manager/training.py index 66f6c03b..a89dc845 100644 --- a/dan/ocr/manager/training.py +++ b/dan/ocr/manager/training.py @@ -21,7 +21,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from tqdm import tqdm -from dan.ocr.decoder import GlobalHTADecoder +from dan.ocr.decoder import CTCLanguageDecoder, GlobalHTADecoder from dan.ocr.encoder import FCN_Encoder from dan.ocr.manager.metrics import Inference, MetricManager from dan.ocr.manager.ocr import OCRDatasetManager @@ -33,6 +33,7 @@ if MLFLOW_AVAILABLE: import mlflow logger = logging.getLogger(__name__) + MODEL_NAME_ENCODER = "encoder" MODEL_NAME_DECODER = "decoder" MODEL_NAMES = (MODEL_NAME_ENCODER, MODEL_NAME_DECODER) @@ -195,6 +196,28 @@ class GenericTrainingManager: output_device=self.ddp_config["rank"], ) + # Instantiate LM decoder + self.lm_decoder = None + if self.params["model"].get("lm") and self.params["model"]["lm"]["weight"] > 0: + logger.info( + f"Decoding with a language model (weight={self.params['model']['lm']['weight']})." + ) + # Check files + model_path = self.params["model"]["lm"]["path"] + assert model_path.is_file(), f"File {model_path} not found" + base_path = model_path.parent + lexicon_path = base_path / "lexicon.txt" + assert lexicon_path.is_file(), f"File {lexicon_path} not found" + tokens_path = base_path / "tokens.txt" + assert tokens_path.is_file(), f"File {tokens_path} not found" + # Load LM decoder + self.lm_decoder = CTCLanguageDecoder( + language_model_path=str(model_path), + lexicon_path=str(lexicon_path), + tokens_path=str(tokens_path), + language_model_weight=self.params["model"]["lm"]["weight"], + ) + # Handle curriculum dropout self.dropout_scheduler = DropoutScheduler(self.models) @@ -816,6 +839,7 @@ class GenericTrainingManager: batch_data["names"], batch_values["str_y"], batch_values["str_x"], + batch_values.get("str_lm", repeat("")), repeat(display_values["wer"]), ) ) @@ -1059,6 +1083,13 @@ class Manager(GenericTrainingManager): ) predicted_tokens_len = torch.ones((b,), dtype=torch.int, device=self.device) + # end token index will be used for ctc + tot_pred = torch.zeros( + (b, len(self.dataset.charset) + 1, max_chars), + dtype=torch.float, + device=self.device, + ) + whole_output = list() confidence_scores = list() cache = None @@ -1112,6 +1143,10 @@ class Manager(GenericTrainingManager): cache=cache, num_pred=1, ) + + # output total logit prediction + tot_pred[:, :, i : i + 1] = pred + whole_output.append(output) confidence_scores.append( torch.max(torch.softmax(pred[:, :], dim=1), dim=1).values @@ -1158,4 +1193,7 @@ class Manager(GenericTrainingManager): "confidence_score": confidence_scores, "time": process_time, } + if self.lm_decoder: + values["str_lm"] = self.lm_decoder(tot_pred, prediction_len)["text"] + return values diff --git a/dan/ocr/utils.py b/dan/ocr/utils.py index cb6fe6e9..740102a6 100644 --- a/dan/ocr/utils.py +++ b/dan/ocr/utils.py @@ -37,6 +37,10 @@ def update_config(config: dict): # .model.decoder.class = GlobalHTADecoder config["model"]["decoder"]["class"] = GlobalHTADecoder + # .model.lm.path to Path + if config["model"].get("lm", {}).get("path"): + config["model"]["lm"]["path"] = Path(config["model"]["lm"]["path"]) + # Update preprocessing type for prepro in config["training"]["data"]["preprocessings"]: prepro["type"] = Preprocessing(prepro["type"]) diff --git a/docs/get_started/training.md b/docs/get_started/training.md index 93a4a38c..368a3b4d 100644 --- a/docs/get_started/training.md +++ b/docs/get_started/training.md @@ -41,4 +41,4 @@ To train a DAN model, please refer to the [documentation of the training command ## 3. Predict -Once the training is complete, you can apply a trained DAN model on an image using the [predict command](../usage/predict/index.md) and the `inference_parameters.yml` file, located in `{training.output_folder}/results`. +Once the training is complete, you can apply a trained DAN model on an image using the [predict command](../usage/predict/index.md) and the `inference_parameters.yml` file, located in `{training.output_folder}/results`. diff --git a/docs/usage/predict/index.md b/docs/usage/predict/index.md index 823399de..656011b1 100644 --- a/docs/usage/predict/index.md +++ b/docs/usage/predict/index.md @@ -166,14 +166,13 @@ It will create the following JSON file named after the image and a GIF showing a This example assumes that you have already [trained a language model](../train/language_model.md). -Note that: - -- the `weight` parameter defines how much weight to give to the language model. It should be set carefully (usually between 0.5 and 2.0) as it will affect the quality of the predictions. -- linebreaks are treated as spaces by language models, as a result predictions will not include linebreaks. +!!! note + - the `weight` parameter defines how much weight to give to the language model. It should be set carefully (usually between 0.5 and 2.0) as it will affect the quality of the predictions. + - linebreaks are treated as spaces by language models, as a result predictions will not include linebreaks. #### Language model at character level -First, update the `inference_parameters.yml` file obtained during DAN training. +Update the `parameters.yml` file obtained during DAN training. ```yaml parameters: @@ -185,8 +184,6 @@ parameters: weight: 0.5 ``` -Note that the `weight` parameter defines how much weight to give to the language model. It should be set carefully (usually between 0.5 and 2.0) as it will affect the quality of the predictions. - Then, run this command: ```shell @@ -211,7 +208,7 @@ It will create the following JSON file named after the image in the `predict_cha #### Language model at subword level -Update the `inference_parameters.yml` file obtained during DAN training. +Update the `parameters.yml` file obtained during DAN training. ```yaml parameters: @@ -247,7 +244,7 @@ It will create the following JSON file named after the image in the `predict_sub #### Language model at word level -Update the `inference_parameters.yml` file obtained during DAN training. +Update the `parameters.yml` file obtained during DAN training. ```yaml parameters: diff --git a/docs/usage/train/config.md b/docs/usage/train/config.md index ac7e1d60..4408d385 100644 --- a/docs/usage/train/config.md +++ b/docs/usage/train/config.md @@ -17,59 +17,131 @@ To determine the value to use for `dataset.max_char_prediction`, you can use the ## Model parameters -| Name | Description | Type | Default | -| ----------------------------------- | ---------------------------------------------------------------------------------- | ------- | ------- | -| `model.transfered_charset` | Transfer learning of the decision layer based on charset of the model to transfer. | `bool` | `True` | -| `model.additional_tokens` | For decision layer = \[<eot>, \], only for transferred charset. | `int` | `1` | -| `model.encoder.dropout` | Dropout probability in the encoder. | `float` | `0.5` | -| `model.encoder.nb_layers` | Number of layers in the encoder. | `int` | `5` | -| `model.h_max` | Maximum height for encoder output (for 2D positional embedding). | `int` | `500` | -| `model.w_max` | Maximum width for encoder output (for 2D positional embedding). | `int` | `1000` | -| `model.decoder.enc_dim` | Dimension of features extracted by the encoder. | `int` | `256` | -| `model.decoder.l_max` | Maximum predicted sequence length (for 1D positional embedding). | `int` | `15000` | -| `model.decoder.dec_num_layers` | Number of transformer decoder layers. | `int` | `8` | -| `model.decoder.dec_num_heads` | Number of heads in transformer decoder layers. | `int` | `4` | -| `model.decoder.dec_res_dropout` | Dropout probability in transformer decoder layers. | `float` | `0.1` | -| `model.decoder.dec_pred_dropout` | Dropout rate before decision layer. | `float` | `0.1` | -| `model.decoder.dec_att_dropout` | Dropout rate in multi head attention. | `float` | `0.1` | -| `model.decoder.dec_dim_feedforward` | Number of dimensions for feedforward layer in transformer decoder layers. | `int` | `256` | -| `model.decoder.attention_win` | Length of attention window. | `int` | `100` | +| Name | Description | Type | Default | +| -------------------------- | ---------------------------------------------------------------------------------- | ------ | ------- | +| `model.transfered_charset` | Transfer learning of the decision layer based on charset of the model to transfer. | `bool` | `True` | +| `model.additional_tokens` | For decision layer = \[<eot>, \], only for transferred charset. | `int` | `1` | +| `model.h_max` | Maximum height for encoder output (for 2D positional embedding). | `int` | `500` | +| `model.w_max` | Maximum width for encoder output (for 2D positional embedding). | `int` | `1000` | + +### Encoder + +| Name | Description | Type | Default | +| ------------------------- | ----------------------------------- | ------- | ------- | +| `model.encoder.dropout` | Dropout probability in the encoder. | `float` | `0.5` | +| `model.encoder.nb_layers` | Number of layers in the encoder. | `int` | `5` | + +### Decoder + +| Name | Description | Type | Default | +| ----------------------------------- | ------------------------------------------------------------------------- | ------- | ------- | +| `model.decoder.enc_dim` | Dimension of features extracted by the encoder. | `int` | `256` | +| `model.decoder.l_max` | Maximum predicted sequence length (for 1D positional embedding). | `int` | `15000` | +| `model.decoder.dec_num_layers` | Number of transformer decoder layers. | `int` | `8` | +| `model.decoder.dec_num_heads` | Number of heads in transformer decoder layers. | `int` | `4` | +| `model.decoder.dec_res_dropout` | Dropout probability in transformer decoder layers. | `float` | `0.1` | +| `model.decoder.dec_pred_dropout` | Dropout rate before decision layer. | `float` | `0.1` | +| `model.decoder.dec_att_dropout` | Dropout rate in multi head attention. | `float` | `0.1` | +| `model.decoder.dec_dim_feedforward` | Number of dimensions for feedforward layer in transformer decoder layers. | `int` | `256` | +| `model.decoder.attention_win` | Length of attention window. | `int` | `100` | + +### Language model + +This assumes that you have already [trained a language model](../train/language_model.md). + +| Name | Description | Type | Default | +| ----------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------- | ------- | ------- | +| `model.lm.path` | Path to the language model. | `str` | | +| `model.lm.weight` | How much weight to give to the language model. It should be set carefully (usually between 0.5 and 2.0) as it will affect the quality of the predictions. | `float` | | + +!!! note + - linebreaks are treated as spaces by language models, as a result predictions will not include linebreaks. + +The `model.lm.path` argument expects a path to the language mode, but the parent folder should also contains: + +- a `lexicon.txt` file, +- a `tokens.txt` file. + +You should get the following tree structure: + +``` +folder/ +├── <model.lm.path> # Path to the language model +├── lexicon.txt +└── tokens.txt +``` ## Training parameters -| Name | Description | Type | Default | -| ------------------------------------------------ | --------------------------------------------------------------------------------------------------------------------------- | ------------ | --------------------------------------------------------------------------- | -| `training.data.batch_size` | Mini-batch size for the training loop. | `int` | `2` | -| `training.data.load_in_memory` | Load all images in CPU memory. | `bool` | `True` | -| `training.data.worker_per_gpu` | Number of parallel processes per gpu for data loading. | `int` | `4` | -| `training.data.preprocessings` | List of pre-processing functions to apply to input images. | `list` | (see [dedicated section](#data-preprocessing)) | -| `training.data.augmentation` | Whether to use data augmentation on the training set. | `bool` | `True` (see [dedicated section](#data-augmentation)) | -| `training.output_folder` | Directory for checkpoint and results. | `str` | | -| `training.max_nb_epochs` | Maximum number of epochs before stopping training. | `int` | `800` | -| `training.load_epoch` | Model to load. Should be either `"best"` (evaluation) or `last` (training). | `str` | `"last"` | -| `training.device.use_ddp` | Whether to use DistributedDataParallel. | `bool` | `False` | -| `training.device.ddp_port` | DDP port. | `int` | `20027` | -| `training.device.use_amp` | Whether to enable automatic mix-precision. | `bool` | `True` | -| `training.device.nb_gpu` | Number of GPUs to train DAN. Set to `null` to use all GPUs available. | `int` | | -| `training.device.force` | Use a specific device if available. Use `cpu` to train on CPU (for debugging) or `cuda`/`cuda:$gpu_device` to train on GPU. | `str` | | -| `training.optimizers.all.args.lr` | Learning rate for the optimizer. | `float` | `0.0001` | -| `training.optimizers.all.args.amsgrad` | Whether to use AMSGrad optimization. | `bool` | `False` | -| `training.lr_schedulers` | Learning rate schedulers. | custom class | | -| `training.validation.eval_on_valid` | Whether to evaluate and log metrics on the validation set during training. | `bool` | `True` | -| `training.validation.eval_on_valid_interval` | Interval (in epochs) to evaluate during training. | `int` | `5` | -| `training.validation.set_name_focus_metric` | Dataset to focus on to select best weights. | `str` | | -| `training.metrics.train` | List of metrics to compute during training. | `list` | `["loss_ce", "cer", "cer_no_token", "wer", "wer_no_punct", "wer_no_token"]` | -| `training.metrics.eval` | List of metrics to compute during validation. | `list` | `["cer", "cer_no_token", "wer", "wer_no_punct", "wer_no_token"]` | -| `training.label_noise_scheduler.min_error_rate` | Minimum ratio of teacher forcing. | `float` | `0.2` | -| `training.label_noise_scheduler.max_error_rate` | Maximum ratio of teacher forcing. | `float` | `0.2` | -| `training.label_noise_scheduler.total_num_steps` | Number of steps before stopping teacher forcing. | `float` | `5e4` | -| `training.transfer_learning.encoder` | Model to load for the encoder \[state_dict_name, checkpoint_path, learnable, strict\]. | `list` | `["encoder", "pretrained_models/dan_rimes_page.pt", True, True]` | -| `training.transfer_learning.decoder` | Model to load for the decoder \[state_dict_name, checkpoint_path, learnable, strict\]. | `list` | `["encoder", "pretrained_models/dan_rimes_page.pt", True, False]` | - -- To train on several GPUs, simply set the `training.use_ddp` parameter to `True`. By default, the model will use all available GPUs. To restrict access to fewer GPUs, one can modify the `training.nb_gpu` parameter. -- During the validation stage, the batch size is set to 1. This avoids problems associated with image sizes that can be very different inside batches and lead to significant padding, resulting in performance degradations. - -### Data preprocessing +| Name | Description | Type | Default | +| ------------------------ | --------------------------------------------------------------------------- | ------------ | -------- | +| `training.output_folder` | Directory for checkpoint and results. | `str` | | +| `training.max_nb_epochs` | Maximum number of epochs before stopping training. | `int` | `800` | +| `training.load_epoch` | Model to load. Should be either `"best"` (evaluation) or `last` (training). | `str` | `"last"` | +| `training.lr_schedulers` | Learning rate schedulers. | custom class | | + +### Device + +| Name | Description | Type | Default | +| -------------------------- | --------------------------------------------------------------------------------------------------------------------------- | ------ | ------- | +| `training.device.use_ddp` | Whether to use DistributedDataParallel. | `bool` | `False` | +| `training.device.ddp_port` | DDP port. | `int` | `20027` | +| `training.device.use_amp` | Whether to enable automatic mix-precision. | `bool` | `True` | +| `training.device.nb_gpu` | Number of GPUs to train DAN. Set to `null` to use all GPUs available. | `int` | | +| `training.device.force` | Use a specific device if available. Use `cpu` to train on CPU (for debugging) or `cuda`/`cuda:$gpu_device` to train on GPU. | `str` | | + +To train on several GPUs, simply set the `training.device.use_ddp` parameter to `True`. By default, the model will use all available GPUs. To restrict access to fewer GPUs, one can modify the `training.device.nb_gpu` parameter. + +### Optimizers + +| Name | Description | Type | Default | +| -------------------------------------- | ------------------------------------ | ------- | -------- | +| `training.optimizers.all.args.lr` | Learning rate for the optimizer. | `float` | `0.0001` | +| `training.optimizers.all.args.amsgrad` | Whether to use AMSGrad optimization. | `bool` | `False` | + +### Validation + +| Name | Description | Type | Default | +| -------------------------------------------- | -------------------------------------------------------------------------- | ------ | ------- | +| `training.validation.eval_on_valid` | Whether to evaluate and log metrics on the validation set during training. | `bool` | `True` | +| `training.validation.eval_on_valid_interval` | Interval (in epochs) to evaluate during training. | `int` | `5` | +| `training.validation.set_name_focus_metric` | Dataset to focus on to select best weights. | `str` | | + +During the validation stage, the batch size is set to 1. This avoids problems associated with image sizes that can be very different inside batches and lead to significant padding, resulting in performance degradations. + +### Metrics + +| Name | Description | Type | Default | +| ------------------------ | --------------------------------------------- | ------ | --------------------------------------------------------------------------- | +| `training.metrics.train` | List of metrics to compute during training. | `list` | `["loss_ce", "cer", "cer_no_token", "wer", "wer_no_punct", "wer_no_token"]` | +| `training.metrics.eval` | List of metrics to compute during validation. | `list` | `["cer", "cer_no_token", "wer", "wer_no_punct", "wer_no_token"]` | + +### Label noise scheduler + +| Name | Description | Type | Default | +| ------------------------------------------------ | ------------------------------------------------ | ------- | ------- | +| `training.label_noise_scheduler.min_error_rate` | Minimum ratio of teacher forcing. | `float` | `0.2` | +| `training.label_noise_scheduler.max_error_rate` | Maximum ratio of teacher forcing. | `float` | `0.2` | +| `training.label_noise_scheduler.total_num_steps` | Number of steps before stopping teacher forcing. | `float` | `5e4` | + +### Transfer learning + +| Name | Description | Type | Default | +| ------------------------------------ | -------------------------------------------------------------------------------------- | ------ | ----------------------------------------------------------------- | +| `training.transfer_learning.encoder` | Model to load for the encoder \[state_dict_name, checkpoint_path, learnable, strict\]. | `list` | `["encoder", "pretrained_models/dan_rimes_page.pt", True, True]` | +| `training.transfer_learning.decoder` | Model to load for the decoder \[state_dict_name, checkpoint_path, learnable, strict\]. | `list` | `["decoder", "pretrained_models/dan_rimes_page.pt", True, False]` | + +### Data + +| Name | Description | Type | Default | +| ------------------------------ | ---------------------------------------------------------- | ------ | ----------------------------------------------- | +| `training.data.batch_size` | Mini-batch size for the training loop. | `int` | `2` | +| `training.data.load_in_memory` | Load all images in CPU memory. | `bool` | `True` | +| `training.data.worker_per_gpu` | Number of parallel processes per gpu for data loading. | `int` | `4` | +| `training.data.preprocessings` | List of pre-processing functions to apply to input images. | `list` | (see [dedicated section](#preprocessing)) | +| `training.data.augmentation` | Whether to use data augmentation on the training set. | `bool` | `True` (see [dedicated section](#augmentation)) | + +#### Preprocessing Preprocessing is applied before training the network (see the [dedicated references](../../ref/ocr/managers/dataset.md)). The list of accepted transforms is defined in the [dedicated references](../../ref/ocr/transforms.md#dan.ocr.transforms.Preprocessing). @@ -124,7 +196,7 @@ Usage: ] ``` -### Data augmentation +#### Augmentation Augmentation transformations are applied on-the-fly during training to artificially increase data variability. diff --git a/tests/data/prediction/language_lexicon.txt b/tests/data/prediction/lexicon.txt similarity index 100% rename from tests/data/prediction/language_lexicon.txt rename to tests/data/prediction/lexicon.txt diff --git a/tests/data/prediction/parameters.yml b/tests/data/prediction/parameters.yml index 88a43d9c..0dd5aeed 100644 --- a/tests/data/prediction/parameters.yml +++ b/tests/data/prediction/parameters.yml @@ -24,6 +24,6 @@ parameters: max_width: 1500 language_model: model: tests/data/prediction/language_model.arpa - lexicon: tests/data/prediction/language_lexicon.txt - tokens: tests/data/prediction/language_tokens.txt + lexicon: tests/data/prediction/lexicon.txt + tokens: tests/data/prediction/tokens.txt weight: 1.0 diff --git a/tests/data/prediction/language_tokens.txt b/tests/data/prediction/tokens.txt similarity index 100% rename from tests/data/prediction/language_tokens.txt rename to tests/data/prediction/tokens.txt diff --git a/tests/test_evaluate.py b/tests/test_evaluate.py index 0bdf5196..90f6ec50 100644 --- a/tests/test_evaluate.py +++ b/tests/test_evaluate.py @@ -8,9 +8,12 @@ import yaml from prettytable import PrettyTable from dan.ocr import evaluate +from dan.ocr.manager.metrics import Inference from dan.ocr.utils import add_metrics_table_row, create_metrics_table from tests import FIXTURES +PREDICTION_DATA_PATH = FIXTURES / "prediction" + def test_create_metrics_table(): metric_names = ["ignored", "wer", "cer", "time", "ner"] @@ -115,14 +118,228 @@ def test_evaluate(capsys, training_res, val_res, test_res, evaluate_config): / f"predict_training-{split_name}_1685.yaml" ) - with filename.open() as f: + assert { + metric: value + for metric, value in yaml.safe_load(filename.read_bytes()).items() # Remove the times from the results as they vary - res = { + if "time" not in metric + } == expected_res + + # Remove results files + shutil.rmtree(evaluate_config["training"]["output_folder"] / "results") + + # Check the metrics Markdown table + captured_std = capsys.readouterr() + last_printed_lines = captured_std.out.split("\n")[10:] + assert ( + "\n".join(last_printed_lines) + == Path(FIXTURES / "evaluate" / "metrics_table.md").read_text() + ) + + +@pytest.mark.parametrize( + "language_model_weight, expected_inferences", + ( + ( + 0.0, + [ + ( + "0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84.png", # Image + "ⓈBellisson â’»Georges â’·91 â“P â’¸M â“€Ch â“„Plombier â“…12241", # Ground truth + "ⓈBellisson â’»Georges â’·91 â“P â’¸M â“€Ch â“„Plombier â“…Patron?12241", # Prediction + "", # LM prediction + 0.125, # WER + ), + ( + "0dfe8bcd-ed0b-453e-bf19-cc697012296e.png", # Image + "ⓈTemplié â’»Marcelle â’·93 â“J â“€ch â“„E dachyle", # Ground truth + "ⓈTemplié â’»Marcelle â’·93 â“S â“€ch â“„E dactylo â“…18376", # Prediction + "", # LM prediction + 0.2667, # WER + ), + ( + "2c242f5c-e979-43c4-b6f2-a6d4815b651d.png", # Image + "ⓈA â’»Charles â’·11 â“P â’¸C â“€F â“„A â“…14331", # Ground truth + "Ⓢd â’»Charles â’·11 â“P â’¸C â“€F â“„d â“…14 31", # Prediction + "", # LM prediction + 0.5, # WER + ), + ( + "ffdec445-7f14-4f5f-be44-68d0844d0df1.png", # Image + "ⓈNaudin â’»Marie â’·53 â“S â’¸V â“€Belle mère", # Ground truth + "ⓈNaudin â’»Marie â’·53 â“S â’¸v â“€Belle mère", # Prediction + "", # LM prediction + 0.1429, # WER + ), + ], + ), + ( + 1.0, + [ + ( + "0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84.png", # Image + "ⓈBellisson â’»Georges â’·91 â“P â’¸M â“€Ch â“„Plombier â“…12241", # Ground truth + "ⓈBellisson â’»Georges â’·91 â“P â’¸M â“€Ch â“„Plombier â“…Patron?12241", # Prediction + "ⓈBellisson â’»Georges â’·91 â“P â’¸M â“€Ch â“„Plombier â“…Patron?12241", # LM prediction + 0.125, # WER + ), + ( + "0dfe8bcd-ed0b-453e-bf19-cc697012296e.png", # Image + "ⓈTemplié â’»Marcelle â’·93 â“J â“€ch â“„E dachyle", # Ground truth + "ⓈTemplié â’»Marcelle â’·93 â“S â“€ch â“„E dactylo â“…18376", # Prediction + "ⓈTemplié â’»Marcelle â’·93 â“S â“€ch â“„E dactylo â“…18376", # LM prediction + 0.2667, # WER + ), + ( + "2c242f5c-e979-43c4-b6f2-a6d4815b651d.png", # Image + "ⓈA â’»Charles â’·11 â“P â’¸C â“€F â“„A â“…14331", # Ground truth + "Ⓢd â’»Charles â’·11 â“P â’¸C â“€F â“„d â“…14 31", # Prediction + "Ⓢd â’»Charles â’·11 â“P â’¸C â“€F â“„d â“…14 31", # LM prediction + 0.5, # WER + ), + ( + "ffdec445-7f14-4f5f-be44-68d0844d0df1.png", # Image + "ⓈNaudin â’»Marie â’·53 â“S â’¸V â“€Belle mère", # Ground truth + "ⓈNaudin â’»Marie â’·53 â“S â’¸v â“€Belle mère", # Prediction + "ⓈNaudin â’»Marie â’·53 â“S â’¸v â“€Belle mère", # LM prediction + 0.1429, # WER + ), + ], + ), + ( + 2.0, + [ + ( + "0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84.png", # Image + "ⓈBellisson â’»Georges â’·91 â“P â’¸M â“€Ch â“„Plombier â“…12241", # Ground truth + "ⓈBellisson â’»Georges â’·91 â“P â’¸M â“€Ch â“„Plombier â“…Patron?12241", # Prediction + "ⓈBellisson â’»Georges â’·91 â“P â’¸M â“€Ch â“„Plombier â“…Patron?12241", # LM prediction + 0.125, # WER + ), + ( + "0dfe8bcd-ed0b-453e-bf19-cc697012296e.png", # Image + "ⓈTemplié â’»Marcelle â’·93 â“J â“€ch â“„E dachyle", # Ground truth + "ⓈTemplié â’»Marcelle â’·93 â“S â“€ch â“„E dactylo â“…18376", # Prediction + "ⓈTemplié â’»Marcelle â’·93 â“S â“€ch â“„E dactylo â“…18376", # LM prediction + 0.2667, # WER + ), + ( + "2c242f5c-e979-43c4-b6f2-a6d4815b651d.png", # Image + "ⓈA â’»Charles â’·11 â“P â’¸C â“€F â“„A â“…14331", # Ground truth + "Ⓢd â’»Charles â’·11 â“P â’¸C â“€F â“„d â“…14 31", # Prediction + "Ⓢd â’»Charles â’·11 â“P â’¸C â“€F â“„d â“…14331", # LM prediction + 0.5, # WER + ), + ( + "ffdec445-7f14-4f5f-be44-68d0844d0df1.png", # Image + "ⓈNaudin â’»Marie â’·53 â“S â’¸V â“€Belle mère", # Ground truth + "ⓈNaudin â’»Marie â’·53 â“S â’¸v â“€Belle mère", # Prediction + "ⓈNaudin â’»Marie â’·53 â“S â’¸v â“€Belle mère", # LM prediction + 0.1429, # WER + ), + ], + ), + ), +) +def test_evaluate_language_model( + capsys, evaluate_config, language_model_weight, expected_inferences, monkeypatch +): + # LM predictions are never used/displayed + # We mock the `Inference` class to temporary check the results + global nb_inferences + nb_inferences = 0 + + class MockInference(Inference): + def __new__(cls, *args, **kwargs): + global nb_inferences + assert args == expected_inferences[nb_inferences] + nb_inferences += 1 + + return super().__new__(cls, *args, **kwargs) + + monkeypatch.setattr("dan.ocr.manager.training.Inference", MockInference) + + # Use the tmp_path as base folder + evaluate_config["training"]["output_folder"] = FIXTURES / "evaluate" + + # Use a LM decoder + evaluate_config["model"]["lm"] = { + "path": PREDICTION_DATA_PATH / "language_model.arpa", + "weight": language_model_weight, + } + + evaluate.run(evaluate_config, evaluate.NERVAL_THRESHOLD) + + # Check that the evaluation results are correct + for split_name, expected_res in [ + ( + "train", + { + "nb_chars": 90, + "cer": 0.1889, + "nb_chars_no_token": 76, + "cer_no_token": 0.2105, + "nb_words": 15, + "wer": 0.2667, + "nb_words_no_punct": 15, + "wer_no_punct": 0.2667, + "nb_words_no_token": 15, + "wer_no_token": 0.2667, + "nb_tokens": 14, + "ner": 0.0714, + "nb_samples": 2, + }, + ), + ( + "val", + { + "nb_chars": 34, + "cer": 0.0882, + "nb_chars_no_token": 26, + "cer_no_token": 0.1154, + "nb_words": 8, + "wer": 0.5, + "nb_words_no_punct": 8, + "wer_no_punct": 0.5, + "nb_words_no_token": 8, + "wer_no_token": 0.5, + "nb_tokens": 8, + "ner": 0.0, + "nb_samples": 1, + }, + ), + ( + "test", + { + "nb_chars": 36, + "cer": 0.0278, + "nb_chars_no_token": 30, + "cer_no_token": 0.0333, + "nb_words": 7, + "wer": 0.1429, + "nb_words_no_punct": 7, + "wer_no_punct": 0.1429, + "nb_words_no_token": 7, + "wer_no_token": 0.1429, + "nb_tokens": 6, + "ner": 0.0, + "nb_samples": 1, + }, + ), + ]: + filename = ( + evaluate_config["training"]["output_folder"] + / "results" + / f"predict_training-{split_name}_1685.yaml" + ) + + with filename.open() as f: + assert { metric: value for metric, value in yaml.safe_load(f).items() + # Remove the times from the results as they vary if "time" not in metric - } - assert res == expected_res + } == expected_res # Remove results files shutil.rmtree(evaluate_config["training"]["output_folder"] / "results") -- GitLab