diff --git a/dan/ocr/manager/metrics.py b/dan/ocr/manager/metrics.py index 07cece06993cea371d6c4f3b34970ac63f8aa924..393044ede1e8ea520379b7b409aa6cd7aacce334 100644 --- a/dan/ocr/manager/metrics.py +++ b/dan/ocr/manager/metrics.py @@ -32,6 +32,7 @@ class Inference(NamedTuple): image: str ground_truth: str prediction: str + lm_prediction: str wer: float diff --git a/dan/ocr/manager/training.py b/dan/ocr/manager/training.py index 66f6c03bf217324e28dacde5d27c6bd91a444a42..a89dc845696a774e4f7b21eafc38339406ddce3c 100644 --- a/dan/ocr/manager/training.py +++ b/dan/ocr/manager/training.py @@ -21,7 +21,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from tqdm import tqdm -from dan.ocr.decoder import GlobalHTADecoder +from dan.ocr.decoder import CTCLanguageDecoder, GlobalHTADecoder from dan.ocr.encoder import FCN_Encoder from dan.ocr.manager.metrics import Inference, MetricManager from dan.ocr.manager.ocr import OCRDatasetManager @@ -33,6 +33,7 @@ if MLFLOW_AVAILABLE: import mlflow logger = logging.getLogger(__name__) + MODEL_NAME_ENCODER = "encoder" MODEL_NAME_DECODER = "decoder" MODEL_NAMES = (MODEL_NAME_ENCODER, MODEL_NAME_DECODER) @@ -195,6 +196,28 @@ class GenericTrainingManager: output_device=self.ddp_config["rank"], ) + # Instantiate LM decoder + self.lm_decoder = None + if self.params["model"].get("lm") and self.params["model"]["lm"]["weight"] > 0: + logger.info( + f"Decoding with a language model (weight={self.params['model']['lm']['weight']})." + ) + # Check files + model_path = self.params["model"]["lm"]["path"] + assert model_path.is_file(), f"File {model_path} not found" + base_path = model_path.parent + lexicon_path = base_path / "lexicon.txt" + assert lexicon_path.is_file(), f"File {lexicon_path} not found" + tokens_path = base_path / "tokens.txt" + assert tokens_path.is_file(), f"File {tokens_path} not found" + # Load LM decoder + self.lm_decoder = CTCLanguageDecoder( + language_model_path=str(model_path), + lexicon_path=str(lexicon_path), + tokens_path=str(tokens_path), + language_model_weight=self.params["model"]["lm"]["weight"], + ) + # Handle curriculum dropout self.dropout_scheduler = DropoutScheduler(self.models) @@ -816,6 +839,7 @@ class GenericTrainingManager: batch_data["names"], batch_values["str_y"], batch_values["str_x"], + batch_values.get("str_lm", repeat("")), repeat(display_values["wer"]), ) ) @@ -1059,6 +1083,13 @@ class Manager(GenericTrainingManager): ) predicted_tokens_len = torch.ones((b,), dtype=torch.int, device=self.device) + # end token index will be used for ctc + tot_pred = torch.zeros( + (b, len(self.dataset.charset) + 1, max_chars), + dtype=torch.float, + device=self.device, + ) + whole_output = list() confidence_scores = list() cache = None @@ -1112,6 +1143,10 @@ class Manager(GenericTrainingManager): cache=cache, num_pred=1, ) + + # output total logit prediction + tot_pred[:, :, i : i + 1] = pred + whole_output.append(output) confidence_scores.append( torch.max(torch.softmax(pred[:, :], dim=1), dim=1).values @@ -1158,4 +1193,7 @@ class Manager(GenericTrainingManager): "confidence_score": confidence_scores, "time": process_time, } + if self.lm_decoder: + values["str_lm"] = self.lm_decoder(tot_pred, prediction_len)["text"] + return values diff --git a/dan/ocr/utils.py b/dan/ocr/utils.py index cb6fe6e97ab7f60aaa171ad7e75453f70505ac50..740102a65fe4a2dec1d56256da3aecfa910e599e 100644 --- a/dan/ocr/utils.py +++ b/dan/ocr/utils.py @@ -37,6 +37,10 @@ def update_config(config: dict): # .model.decoder.class = GlobalHTADecoder config["model"]["decoder"]["class"] = GlobalHTADecoder + # .model.lm.path to Path + if config["model"].get("lm", {}).get("path"): + config["model"]["lm"]["path"] = Path(config["model"]["lm"]["path"]) + # Update preprocessing type for prepro in config["training"]["data"]["preprocessings"]: prepro["type"] = Preprocessing(prepro["type"]) diff --git a/docs/get_started/training.md b/docs/get_started/training.md index 93a4a38c4150323ebbb9b83b8c4cb68b412e3865..368a3b4dc84008d738b24715cda2be5ab0e116dc 100644 --- a/docs/get_started/training.md +++ b/docs/get_started/training.md @@ -41,4 +41,4 @@ To train a DAN model, please refer to the [documentation of the training command ## 3. Predict -Once the training is complete, you can apply a trained DAN model on an image using the [predict command](../usage/predict/index.md) and the `inference_parameters.yml` file, located in `{training.output_folder}/results`. +Once the training is complete, you can apply a trained DAN model on an image using the [predict command](../usage/predict/index.md) and the `inference_parameters.yml` file, located in `{training.output_folder}/results`. diff --git a/docs/usage/predict/index.md b/docs/usage/predict/index.md index 823399de80299417f9e9fa28a66ed0aaf67fe771..656011b1966cefc03de25ebbe99448ded57feeab 100644 --- a/docs/usage/predict/index.md +++ b/docs/usage/predict/index.md @@ -166,14 +166,13 @@ It will create the following JSON file named after the image and a GIF showing a This example assumes that you have already [trained a language model](../train/language_model.md). -Note that: - -- the `weight` parameter defines how much weight to give to the language model. It should be set carefully (usually between 0.5 and 2.0) as it will affect the quality of the predictions. -- linebreaks are treated as spaces by language models, as a result predictions will not include linebreaks. +!!! note + - the `weight` parameter defines how much weight to give to the language model. It should be set carefully (usually between 0.5 and 2.0) as it will affect the quality of the predictions. + - linebreaks are treated as spaces by language models, as a result predictions will not include linebreaks. #### Language model at character level -First, update the `inference_parameters.yml` file obtained during DAN training. +Update the `parameters.yml` file obtained during DAN training. ```yaml parameters: @@ -185,8 +184,6 @@ parameters: weight: 0.5 ``` -Note that the `weight` parameter defines how much weight to give to the language model. It should be set carefully (usually between 0.5 and 2.0) as it will affect the quality of the predictions. - Then, run this command: ```shell @@ -211,7 +208,7 @@ It will create the following JSON file named after the image in the `predict_cha #### Language model at subword level -Update the `inference_parameters.yml` file obtained during DAN training. +Update the `parameters.yml` file obtained during DAN training. ```yaml parameters: @@ -247,7 +244,7 @@ It will create the following JSON file named after the image in the `predict_sub #### Language model at word level -Update the `inference_parameters.yml` file obtained during DAN training. +Update the `parameters.yml` file obtained during DAN training. ```yaml parameters: diff --git a/docs/usage/train/config.md b/docs/usage/train/config.md index ac7e1d60102bfcfc5bce97b6e24124590bdce402..4408d3852c2336c900f8ba509d15bc43f89c6fd6 100644 --- a/docs/usage/train/config.md +++ b/docs/usage/train/config.md @@ -17,59 +17,131 @@ To determine the value to use for `dataset.max_char_prediction`, you can use the ## Model parameters -| Name | Description | Type | Default | -| ----------------------------------- | ---------------------------------------------------------------------------------- | ------- | ------- | -| `model.transfered_charset` | Transfer learning of the decision layer based on charset of the model to transfer. | `bool` | `True` | -| `model.additional_tokens` | For decision layer = \[<eot>, \], only for transferred charset. | `int` | `1` | -| `model.encoder.dropout` | Dropout probability in the encoder. | `float` | `0.5` | -| `model.encoder.nb_layers` | Number of layers in the encoder. | `int` | `5` | -| `model.h_max` | Maximum height for encoder output (for 2D positional embedding). | `int` | `500` | -| `model.w_max` | Maximum width for encoder output (for 2D positional embedding). | `int` | `1000` | -| `model.decoder.enc_dim` | Dimension of features extracted by the encoder. | `int` | `256` | -| `model.decoder.l_max` | Maximum predicted sequence length (for 1D positional embedding). | `int` | `15000` | -| `model.decoder.dec_num_layers` | Number of transformer decoder layers. | `int` | `8` | -| `model.decoder.dec_num_heads` | Number of heads in transformer decoder layers. | `int` | `4` | -| `model.decoder.dec_res_dropout` | Dropout probability in transformer decoder layers. | `float` | `0.1` | -| `model.decoder.dec_pred_dropout` | Dropout rate before decision layer. | `float` | `0.1` | -| `model.decoder.dec_att_dropout` | Dropout rate in multi head attention. | `float` | `0.1` | -| `model.decoder.dec_dim_feedforward` | Number of dimensions for feedforward layer in transformer decoder layers. | `int` | `256` | -| `model.decoder.attention_win` | Length of attention window. | `int` | `100` | +| Name | Description | Type | Default | +| -------------------------- | ---------------------------------------------------------------------------------- | ------ | ------- | +| `model.transfered_charset` | Transfer learning of the decision layer based on charset of the model to transfer. | `bool` | `True` | +| `model.additional_tokens` | For decision layer = \[<eot>, \], only for transferred charset. | `int` | `1` | +| `model.h_max` | Maximum height for encoder output (for 2D positional embedding). | `int` | `500` | +| `model.w_max` | Maximum width for encoder output (for 2D positional embedding). | `int` | `1000` | + +### Encoder + +| Name | Description | Type | Default | +| ------------------------- | ----------------------------------- | ------- | ------- | +| `model.encoder.dropout` | Dropout probability in the encoder. | `float` | `0.5` | +| `model.encoder.nb_layers` | Number of layers in the encoder. | `int` | `5` | + +### Decoder + +| Name | Description | Type | Default | +| ----------------------------------- | ------------------------------------------------------------------------- | ------- | ------- | +| `model.decoder.enc_dim` | Dimension of features extracted by the encoder. | `int` | `256` | +| `model.decoder.l_max` | Maximum predicted sequence length (for 1D positional embedding). | `int` | `15000` | +| `model.decoder.dec_num_layers` | Number of transformer decoder layers. | `int` | `8` | +| `model.decoder.dec_num_heads` | Number of heads in transformer decoder layers. | `int` | `4` | +| `model.decoder.dec_res_dropout` | Dropout probability in transformer decoder layers. | `float` | `0.1` | +| `model.decoder.dec_pred_dropout` | Dropout rate before decision layer. | `float` | `0.1` | +| `model.decoder.dec_att_dropout` | Dropout rate in multi head attention. | `float` | `0.1` | +| `model.decoder.dec_dim_feedforward` | Number of dimensions for feedforward layer in transformer decoder layers. | `int` | `256` | +| `model.decoder.attention_win` | Length of attention window. | `int` | `100` | + +### Language model + +This assumes that you have already [trained a language model](../train/language_model.md). + +| Name | Description | Type | Default | +| ----------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------- | ------- | ------- | +| `model.lm.path` | Path to the language model. | `str` | | +| `model.lm.weight` | How much weight to give to the language model. It should be set carefully (usually between 0.5 and 2.0) as it will affect the quality of the predictions. | `float` | | + +!!! note + - linebreaks are treated as spaces by language models, as a result predictions will not include linebreaks. + +The `model.lm.path` argument expects a path to the language mode, but the parent folder should also contains: + +- a `lexicon.txt` file, +- a `tokens.txt` file. + +You should get the following tree structure: + +``` +folder/ +├── <model.lm.path> # Path to the language model +├── lexicon.txt +└── tokens.txt +``` ## Training parameters -| Name | Description | Type | Default | -| ------------------------------------------------ | --------------------------------------------------------------------------------------------------------------------------- | ------------ | --------------------------------------------------------------------------- | -| `training.data.batch_size` | Mini-batch size for the training loop. | `int` | `2` | -| `training.data.load_in_memory` | Load all images in CPU memory. | `bool` | `True` | -| `training.data.worker_per_gpu` | Number of parallel processes per gpu for data loading. | `int` | `4` | -| `training.data.preprocessings` | List of pre-processing functions to apply to input images. | `list` | (see [dedicated section](#data-preprocessing)) | -| `training.data.augmentation` | Whether to use data augmentation on the training set. | `bool` | `True` (see [dedicated section](#data-augmentation)) | -| `training.output_folder` | Directory for checkpoint and results. | `str` | | -| `training.max_nb_epochs` | Maximum number of epochs before stopping training. | `int` | `800` | -| `training.load_epoch` | Model to load. Should be either `"best"` (evaluation) or `last` (training). | `str` | `"last"` | -| `training.device.use_ddp` | Whether to use DistributedDataParallel. | `bool` | `False` | -| `training.device.ddp_port` | DDP port. | `int` | `20027` | -| `training.device.use_amp` | Whether to enable automatic mix-precision. | `bool` | `True` | -| `training.device.nb_gpu` | Number of GPUs to train DAN. Set to `null` to use all GPUs available. | `int` | | -| `training.device.force` | Use a specific device if available. Use `cpu` to train on CPU (for debugging) or `cuda`/`cuda:$gpu_device` to train on GPU. | `str` | | -| `training.optimizers.all.args.lr` | Learning rate for the optimizer. | `float` | `0.0001` | -| `training.optimizers.all.args.amsgrad` | Whether to use AMSGrad optimization. | `bool` | `False` | -| `training.lr_schedulers` | Learning rate schedulers. | custom class | | -| `training.validation.eval_on_valid` | Whether to evaluate and log metrics on the validation set during training. | `bool` | `True` | -| `training.validation.eval_on_valid_interval` | Interval (in epochs) to evaluate during training. | `int` | `5` | -| `training.validation.set_name_focus_metric` | Dataset to focus on to select best weights. | `str` | | -| `training.metrics.train` | List of metrics to compute during training. | `list` | `["loss_ce", "cer", "cer_no_token", "wer", "wer_no_punct", "wer_no_token"]` | -| `training.metrics.eval` | List of metrics to compute during validation. | `list` | `["cer", "cer_no_token", "wer", "wer_no_punct", "wer_no_token"]` | -| `training.label_noise_scheduler.min_error_rate` | Minimum ratio of teacher forcing. | `float` | `0.2` | -| `training.label_noise_scheduler.max_error_rate` | Maximum ratio of teacher forcing. | `float` | `0.2` | -| `training.label_noise_scheduler.total_num_steps` | Number of steps before stopping teacher forcing. | `float` | `5e4` | -| `training.transfer_learning.encoder` | Model to load for the encoder \[state_dict_name, checkpoint_path, learnable, strict\]. | `list` | `["encoder", "pretrained_models/dan_rimes_page.pt", True, True]` | -| `training.transfer_learning.decoder` | Model to load for the decoder \[state_dict_name, checkpoint_path, learnable, strict\]. | `list` | `["encoder", "pretrained_models/dan_rimes_page.pt", True, False]` | - -- To train on several GPUs, simply set the `training.use_ddp` parameter to `True`. By default, the model will use all available GPUs. To restrict access to fewer GPUs, one can modify the `training.nb_gpu` parameter. -- During the validation stage, the batch size is set to 1. This avoids problems associated with image sizes that can be very different inside batches and lead to significant padding, resulting in performance degradations. - -### Data preprocessing +| Name | Description | Type | Default | +| ------------------------ | --------------------------------------------------------------------------- | ------------ | -------- | +| `training.output_folder` | Directory for checkpoint and results. | `str` | | +| `training.max_nb_epochs` | Maximum number of epochs before stopping training. | `int` | `800` | +| `training.load_epoch` | Model to load. Should be either `"best"` (evaluation) or `last` (training). | `str` | `"last"` | +| `training.lr_schedulers` | Learning rate schedulers. | custom class | | + +### Device + +| Name | Description | Type | Default | +| -------------------------- | --------------------------------------------------------------------------------------------------------------------------- | ------ | ------- | +| `training.device.use_ddp` | Whether to use DistributedDataParallel. | `bool` | `False` | +| `training.device.ddp_port` | DDP port. | `int` | `20027` | +| `training.device.use_amp` | Whether to enable automatic mix-precision. | `bool` | `True` | +| `training.device.nb_gpu` | Number of GPUs to train DAN. Set to `null` to use all GPUs available. | `int` | | +| `training.device.force` | Use a specific device if available. Use `cpu` to train on CPU (for debugging) or `cuda`/`cuda:$gpu_device` to train on GPU. | `str` | | + +To train on several GPUs, simply set the `training.device.use_ddp` parameter to `True`. By default, the model will use all available GPUs. To restrict access to fewer GPUs, one can modify the `training.device.nb_gpu` parameter. + +### Optimizers + +| Name | Description | Type | Default | +| -------------------------------------- | ------------------------------------ | ------- | -------- | +| `training.optimizers.all.args.lr` | Learning rate for the optimizer. | `float` | `0.0001` | +| `training.optimizers.all.args.amsgrad` | Whether to use AMSGrad optimization. | `bool` | `False` | + +### Validation + +| Name | Description | Type | Default | +| -------------------------------------------- | -------------------------------------------------------------------------- | ------ | ------- | +| `training.validation.eval_on_valid` | Whether to evaluate and log metrics on the validation set during training. | `bool` | `True` | +| `training.validation.eval_on_valid_interval` | Interval (in epochs) to evaluate during training. | `int` | `5` | +| `training.validation.set_name_focus_metric` | Dataset to focus on to select best weights. | `str` | | + +During the validation stage, the batch size is set to 1. This avoids problems associated with image sizes that can be very different inside batches and lead to significant padding, resulting in performance degradations. + +### Metrics + +| Name | Description | Type | Default | +| ------------------------ | --------------------------------------------- | ------ | --------------------------------------------------------------------------- | +| `training.metrics.train` | List of metrics to compute during training. | `list` | `["loss_ce", "cer", "cer_no_token", "wer", "wer_no_punct", "wer_no_token"]` | +| `training.metrics.eval` | List of metrics to compute during validation. | `list` | `["cer", "cer_no_token", "wer", "wer_no_punct", "wer_no_token"]` | + +### Label noise scheduler + +| Name | Description | Type | Default | +| ------------------------------------------------ | ------------------------------------------------ | ------- | ------- | +| `training.label_noise_scheduler.min_error_rate` | Minimum ratio of teacher forcing. | `float` | `0.2` | +| `training.label_noise_scheduler.max_error_rate` | Maximum ratio of teacher forcing. | `float` | `0.2` | +| `training.label_noise_scheduler.total_num_steps` | Number of steps before stopping teacher forcing. | `float` | `5e4` | + +### Transfer learning + +| Name | Description | Type | Default | +| ------------------------------------ | -------------------------------------------------------------------------------------- | ------ | ----------------------------------------------------------------- | +| `training.transfer_learning.encoder` | Model to load for the encoder \[state_dict_name, checkpoint_path, learnable, strict\]. | `list` | `["encoder", "pretrained_models/dan_rimes_page.pt", True, True]` | +| `training.transfer_learning.decoder` | Model to load for the decoder \[state_dict_name, checkpoint_path, learnable, strict\]. | `list` | `["decoder", "pretrained_models/dan_rimes_page.pt", True, False]` | + +### Data + +| Name | Description | Type | Default | +| ------------------------------ | ---------------------------------------------------------- | ------ | ----------------------------------------------- | +| `training.data.batch_size` | Mini-batch size for the training loop. | `int` | `2` | +| `training.data.load_in_memory` | Load all images in CPU memory. | `bool` | `True` | +| `training.data.worker_per_gpu` | Number of parallel processes per gpu for data loading. | `int` | `4` | +| `training.data.preprocessings` | List of pre-processing functions to apply to input images. | `list` | (see [dedicated section](#preprocessing)) | +| `training.data.augmentation` | Whether to use data augmentation on the training set. | `bool` | `True` (see [dedicated section](#augmentation)) | + +#### Preprocessing Preprocessing is applied before training the network (see the [dedicated references](../../ref/ocr/managers/dataset.md)). The list of accepted transforms is defined in the [dedicated references](../../ref/ocr/transforms.md#dan.ocr.transforms.Preprocessing). @@ -124,7 +196,7 @@ Usage: ] ``` -### Data augmentation +#### Augmentation Augmentation transformations are applied on-the-fly during training to artificially increase data variability. diff --git a/tests/data/prediction/language_lexicon.txt b/tests/data/prediction/lexicon.txt similarity index 100% rename from tests/data/prediction/language_lexicon.txt rename to tests/data/prediction/lexicon.txt diff --git a/tests/data/prediction/parameters.yml b/tests/data/prediction/parameters.yml index 88a43d9c171ee483e734f01b0d99250f10c5d5ec..0dd5aeed588069f02fd9c89066813246524c71e3 100644 --- a/tests/data/prediction/parameters.yml +++ b/tests/data/prediction/parameters.yml @@ -24,6 +24,6 @@ parameters: max_width: 1500 language_model: model: tests/data/prediction/language_model.arpa - lexicon: tests/data/prediction/language_lexicon.txt - tokens: tests/data/prediction/language_tokens.txt + lexicon: tests/data/prediction/lexicon.txt + tokens: tests/data/prediction/tokens.txt weight: 1.0 diff --git a/tests/data/prediction/language_tokens.txt b/tests/data/prediction/tokens.txt similarity index 100% rename from tests/data/prediction/language_tokens.txt rename to tests/data/prediction/tokens.txt diff --git a/tests/test_evaluate.py b/tests/test_evaluate.py index 0bdf51968cf985a7c1fa1dcac00aedbe4b474091..90f6ec50f26611ff43a56eef56e7c6fa822e87fa 100644 --- a/tests/test_evaluate.py +++ b/tests/test_evaluate.py @@ -8,9 +8,12 @@ import yaml from prettytable import PrettyTable from dan.ocr import evaluate +from dan.ocr.manager.metrics import Inference from dan.ocr.utils import add_metrics_table_row, create_metrics_table from tests import FIXTURES +PREDICTION_DATA_PATH = FIXTURES / "prediction" + def test_create_metrics_table(): metric_names = ["ignored", "wer", "cer", "time", "ner"] @@ -115,14 +118,228 @@ def test_evaluate(capsys, training_res, val_res, test_res, evaluate_config): / f"predict_training-{split_name}_1685.yaml" ) - with filename.open() as f: + assert { + metric: value + for metric, value in yaml.safe_load(filename.read_bytes()).items() # Remove the times from the results as they vary - res = { + if "time" not in metric + } == expected_res + + # Remove results files + shutil.rmtree(evaluate_config["training"]["output_folder"] / "results") + + # Check the metrics Markdown table + captured_std = capsys.readouterr() + last_printed_lines = captured_std.out.split("\n")[10:] + assert ( + "\n".join(last_printed_lines) + == Path(FIXTURES / "evaluate" / "metrics_table.md").read_text() + ) + + +@pytest.mark.parametrize( + "language_model_weight, expected_inferences", + ( + ( + 0.0, + [ + ( + "0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84.png", # Image + "ⓈBellisson â’»Georges â’·91 â“P â’¸M â“€Ch â“„Plombier â“…12241", # Ground truth + "ⓈBellisson â’»Georges â’·91 â“P â’¸M â“€Ch â“„Plombier â“…Patron?12241", # Prediction + "", # LM prediction + 0.125, # WER + ), + ( + "0dfe8bcd-ed0b-453e-bf19-cc697012296e.png", # Image + "ⓈTemplié â’»Marcelle â’·93 â“J â“€ch â“„E dachyle", # Ground truth + "ⓈTemplié â’»Marcelle â’·93 â“S â“€ch â“„E dactylo â“…18376", # Prediction + "", # LM prediction + 0.2667, # WER + ), + ( + "2c242f5c-e979-43c4-b6f2-a6d4815b651d.png", # Image + "ⓈA â’»Charles â’·11 â“P â’¸C â“€F â“„A â“…14331", # Ground truth + "Ⓢd â’»Charles â’·11 â“P â’¸C â“€F â“„d â“…14 31", # Prediction + "", # LM prediction + 0.5, # WER + ), + ( + "ffdec445-7f14-4f5f-be44-68d0844d0df1.png", # Image + "ⓈNaudin â’»Marie â’·53 â“S â’¸V â“€Belle mère", # Ground truth + "ⓈNaudin â’»Marie â’·53 â“S â’¸v â“€Belle mère", # Prediction + "", # LM prediction + 0.1429, # WER + ), + ], + ), + ( + 1.0, + [ + ( + "0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84.png", # Image + "ⓈBellisson â’»Georges â’·91 â“P â’¸M â“€Ch â“„Plombier â“…12241", # Ground truth + "ⓈBellisson â’»Georges â’·91 â“P â’¸M â“€Ch â“„Plombier â“…Patron?12241", # Prediction + "ⓈBellisson â’»Georges â’·91 â“P â’¸M â“€Ch â“„Plombier â“…Patron?12241", # LM prediction + 0.125, # WER + ), + ( + "0dfe8bcd-ed0b-453e-bf19-cc697012296e.png", # Image + "ⓈTemplié â’»Marcelle â’·93 â“J â“€ch â“„E dachyle", # Ground truth + "ⓈTemplié â’»Marcelle â’·93 â“S â“€ch â“„E dactylo â“…18376", # Prediction + "ⓈTemplié â’»Marcelle â’·93 â“S â“€ch â“„E dactylo â“…18376", # LM prediction + 0.2667, # WER + ), + ( + "2c242f5c-e979-43c4-b6f2-a6d4815b651d.png", # Image + "ⓈA â’»Charles â’·11 â“P â’¸C â“€F â“„A â“…14331", # Ground truth + "Ⓢd â’»Charles â’·11 â“P â’¸C â“€F â“„d â“…14 31", # Prediction + "Ⓢd â’»Charles â’·11 â“P â’¸C â“€F â“„d â“…14 31", # LM prediction + 0.5, # WER + ), + ( + "ffdec445-7f14-4f5f-be44-68d0844d0df1.png", # Image + "ⓈNaudin â’»Marie â’·53 â“S â’¸V â“€Belle mère", # Ground truth + "ⓈNaudin â’»Marie â’·53 â“S â’¸v â“€Belle mère", # Prediction + "ⓈNaudin â’»Marie â’·53 â“S â’¸v â“€Belle mère", # LM prediction + 0.1429, # WER + ), + ], + ), + ( + 2.0, + [ + ( + "0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84.png", # Image + "ⓈBellisson â’»Georges â’·91 â“P â’¸M â“€Ch â“„Plombier â“…12241", # Ground truth + "ⓈBellisson â’»Georges â’·91 â“P â’¸M â“€Ch â“„Plombier â“…Patron?12241", # Prediction + "ⓈBellisson â’»Georges â’·91 â“P â’¸M â“€Ch â“„Plombier â“…Patron?12241", # LM prediction + 0.125, # WER + ), + ( + "0dfe8bcd-ed0b-453e-bf19-cc697012296e.png", # Image + "ⓈTemplié â’»Marcelle â’·93 â“J â“€ch â“„E dachyle", # Ground truth + "ⓈTemplié â’»Marcelle â’·93 â“S â“€ch â“„E dactylo â“…18376", # Prediction + "ⓈTemplié â’»Marcelle â’·93 â“S â“€ch â“„E dactylo â“…18376", # LM prediction + 0.2667, # WER + ), + ( + "2c242f5c-e979-43c4-b6f2-a6d4815b651d.png", # Image + "ⓈA â’»Charles â’·11 â“P â’¸C â“€F â“„A â“…14331", # Ground truth + "Ⓢd â’»Charles â’·11 â“P â’¸C â“€F â“„d â“…14 31", # Prediction + "Ⓢd â’»Charles â’·11 â“P â’¸C â“€F â“„d â“…14331", # LM prediction + 0.5, # WER + ), + ( + "ffdec445-7f14-4f5f-be44-68d0844d0df1.png", # Image + "ⓈNaudin â’»Marie â’·53 â“S â’¸V â“€Belle mère", # Ground truth + "ⓈNaudin â’»Marie â’·53 â“S â’¸v â“€Belle mère", # Prediction + "ⓈNaudin â’»Marie â’·53 â“S â’¸v â“€Belle mère", # LM prediction + 0.1429, # WER + ), + ], + ), + ), +) +def test_evaluate_language_model( + capsys, evaluate_config, language_model_weight, expected_inferences, monkeypatch +): + # LM predictions are never used/displayed + # We mock the `Inference` class to temporary check the results + global nb_inferences + nb_inferences = 0 + + class MockInference(Inference): + def __new__(cls, *args, **kwargs): + global nb_inferences + assert args == expected_inferences[nb_inferences] + nb_inferences += 1 + + return super().__new__(cls, *args, **kwargs) + + monkeypatch.setattr("dan.ocr.manager.training.Inference", MockInference) + + # Use the tmp_path as base folder + evaluate_config["training"]["output_folder"] = FIXTURES / "evaluate" + + # Use a LM decoder + evaluate_config["model"]["lm"] = { + "path": PREDICTION_DATA_PATH / "language_model.arpa", + "weight": language_model_weight, + } + + evaluate.run(evaluate_config, evaluate.NERVAL_THRESHOLD) + + # Check that the evaluation results are correct + for split_name, expected_res in [ + ( + "train", + { + "nb_chars": 90, + "cer": 0.1889, + "nb_chars_no_token": 76, + "cer_no_token": 0.2105, + "nb_words": 15, + "wer": 0.2667, + "nb_words_no_punct": 15, + "wer_no_punct": 0.2667, + "nb_words_no_token": 15, + "wer_no_token": 0.2667, + "nb_tokens": 14, + "ner": 0.0714, + "nb_samples": 2, + }, + ), + ( + "val", + { + "nb_chars": 34, + "cer": 0.0882, + "nb_chars_no_token": 26, + "cer_no_token": 0.1154, + "nb_words": 8, + "wer": 0.5, + "nb_words_no_punct": 8, + "wer_no_punct": 0.5, + "nb_words_no_token": 8, + "wer_no_token": 0.5, + "nb_tokens": 8, + "ner": 0.0, + "nb_samples": 1, + }, + ), + ( + "test", + { + "nb_chars": 36, + "cer": 0.0278, + "nb_chars_no_token": 30, + "cer_no_token": 0.0333, + "nb_words": 7, + "wer": 0.1429, + "nb_words_no_punct": 7, + "wer_no_punct": 0.1429, + "nb_words_no_token": 7, + "wer_no_token": 0.1429, + "nb_tokens": 6, + "ner": 0.0, + "nb_samples": 1, + }, + ), + ]: + filename = ( + evaluate_config["training"]["output_folder"] + / "results" + / f"predict_training-{split_name}_1685.yaml" + ) + + with filename.open() as f: + assert { metric: value for metric, value in yaml.safe_load(f).items() + # Remove the times from the results as they vary if "time" not in metric - } - assert res == expected_res + } == expected_res # Remove results files shutil.rmtree(evaluate_config["training"]["output_folder"] / "results")