Skip to content
Snippets Groups Projects
Commit 2e719c9e authored by Yoann Schneider's avatar Yoann Schneider :tennis:
Browse files

Merge branch 'eval-load-lm' into 'main'

Load a language model and decode with it during evaluation

Closes #252

See merge request !347
parents 389c3505 48a4fceb
No related branches found
No related tags found
1 merge request!347Load a language model and decode with it during evaluation
......@@ -32,6 +32,7 @@ class Inference(NamedTuple):
image: str
ground_truth: str
prediction: str
lm_prediction: str
wer: float
......
......@@ -21,7 +21,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from dan.ocr.decoder import GlobalHTADecoder
from dan.ocr.decoder import CTCLanguageDecoder, GlobalHTADecoder
from dan.ocr.encoder import FCN_Encoder
from dan.ocr.manager.metrics import Inference, MetricManager
from dan.ocr.manager.ocr import OCRDatasetManager
......@@ -33,6 +33,7 @@ if MLFLOW_AVAILABLE:
import mlflow
logger = logging.getLogger(__name__)
MODEL_NAME_ENCODER = "encoder"
MODEL_NAME_DECODER = "decoder"
MODEL_NAMES = (MODEL_NAME_ENCODER, MODEL_NAME_DECODER)
......@@ -195,6 +196,28 @@ class GenericTrainingManager:
output_device=self.ddp_config["rank"],
)
# Instantiate LM decoder
self.lm_decoder = None
if self.params["model"].get("lm") and self.params["model"]["lm"]["weight"] > 0:
logger.info(
f"Decoding with a language model (weight={self.params['model']['lm']['weight']})."
)
# Check files
model_path = self.params["model"]["lm"]["path"]
assert model_path.is_file(), f"File {model_path} not found"
base_path = model_path.parent
lexicon_path = base_path / "lexicon.txt"
assert lexicon_path.is_file(), f"File {lexicon_path} not found"
tokens_path = base_path / "tokens.txt"
assert tokens_path.is_file(), f"File {tokens_path} not found"
# Load LM decoder
self.lm_decoder = CTCLanguageDecoder(
language_model_path=str(model_path),
lexicon_path=str(lexicon_path),
tokens_path=str(tokens_path),
language_model_weight=self.params["model"]["lm"]["weight"],
)
# Handle curriculum dropout
self.dropout_scheduler = DropoutScheduler(self.models)
......@@ -816,6 +839,7 @@ class GenericTrainingManager:
batch_data["names"],
batch_values["str_y"],
batch_values["str_x"],
batch_values.get("str_lm", repeat("")),
repeat(display_values["wer"]),
)
)
......@@ -1059,6 +1083,13 @@ class Manager(GenericTrainingManager):
)
predicted_tokens_len = torch.ones((b,), dtype=torch.int, device=self.device)
# end token index will be used for ctc
tot_pred = torch.zeros(
(b, len(self.dataset.charset) + 1, max_chars),
dtype=torch.float,
device=self.device,
)
whole_output = list()
confidence_scores = list()
cache = None
......@@ -1112,6 +1143,10 @@ class Manager(GenericTrainingManager):
cache=cache,
num_pred=1,
)
# output total logit prediction
tot_pred[:, :, i : i + 1] = pred
whole_output.append(output)
confidence_scores.append(
torch.max(torch.softmax(pred[:, :], dim=1), dim=1).values
......@@ -1158,4 +1193,7 @@ class Manager(GenericTrainingManager):
"confidence_score": confidence_scores,
"time": process_time,
}
if self.lm_decoder:
values["str_lm"] = self.lm_decoder(tot_pred, prediction_len)["text"]
return values
......@@ -37,6 +37,10 @@ def update_config(config: dict):
# .model.decoder.class = GlobalHTADecoder
config["model"]["decoder"]["class"] = GlobalHTADecoder
# .model.lm.path to Path
if config["model"].get("lm", {}).get("path"):
config["model"]["lm"]["path"] = Path(config["model"]["lm"]["path"])
# Update preprocessing type
for prepro in config["training"]["data"]["preprocessings"]:
prepro["type"] = Preprocessing(prepro["type"])
......
......@@ -41,4 +41,4 @@ To train a DAN model, please refer to the [documentation of the training command
## 3. Predict
Once the training is complete, you can apply a trained DAN model on an image using the [predict command](../usage/predict/index.md) and the `inference_parameters.yml` file, located in `{training.output_folder}/results`.
Once the training is complete, you can apply a trained DAN model on an image using the [predict command](../usage/predict/index.md) and the `inference_parameters.yml` file, located in `{training.output_folder}/results`.
......@@ -166,14 +166,13 @@ It will create the following JSON file named after the image and a GIF showing a
This example assumes that you have already [trained a language model](../train/language_model.md).
Note that:
- the `weight` parameter defines how much weight to give to the language model. It should be set carefully (usually between 0.5 and 2.0) as it will affect the quality of the predictions.
- linebreaks are treated as spaces by language models, as a result predictions will not include linebreaks.
!!! note
- the `weight` parameter defines how much weight to give to the language model. It should be set carefully (usually between 0.5 and 2.0) as it will affect the quality of the predictions.
- linebreaks are treated as spaces by language models, as a result predictions will not include linebreaks.
#### Language model at character level
First, update the `inference_parameters.yml` file obtained during DAN training.
Update the `parameters.yml` file obtained during DAN training.
```yaml
parameters:
......@@ -185,8 +184,6 @@ parameters:
weight: 0.5
```
Note that the `weight` parameter defines how much weight to give to the language model. It should be set carefully (usually between 0.5 and 2.0) as it will affect the quality of the predictions.
Then, run this command:
```shell
......@@ -211,7 +208,7 @@ It will create the following JSON file named after the image in the `predict_cha
#### Language model at subword level
Update the `inference_parameters.yml` file obtained during DAN training.
Update the `parameters.yml` file obtained during DAN training.
```yaml
parameters:
......@@ -247,7 +244,7 @@ It will create the following JSON file named after the image in the `predict_sub
#### Language model at word level
Update the `inference_parameters.yml` file obtained during DAN training.
Update the `parameters.yml` file obtained during DAN training.
```yaml
parameters:
......
This diff is collapsed.
......@@ -24,6 +24,6 @@ parameters:
max_width: 1500
language_model:
model: tests/data/prediction/language_model.arpa
lexicon: tests/data/prediction/language_lexicon.txt
tokens: tests/data/prediction/language_tokens.txt
lexicon: tests/data/prediction/lexicon.txt
tokens: tests/data/prediction/tokens.txt
weight: 1.0
......@@ -8,9 +8,12 @@ import yaml
from prettytable import PrettyTable
from dan.ocr import evaluate
from dan.ocr.manager.metrics import Inference
from dan.ocr.utils import add_metrics_table_row, create_metrics_table
from tests import FIXTURES
PREDICTION_DATA_PATH = FIXTURES / "prediction"
def test_create_metrics_table():
metric_names = ["ignored", "wer", "cer", "time", "ner"]
......@@ -115,14 +118,228 @@ def test_evaluate(capsys, training_res, val_res, test_res, evaluate_config):
/ f"predict_training-{split_name}_1685.yaml"
)
with filename.open() as f:
assert {
metric: value
for metric, value in yaml.safe_load(filename.read_bytes()).items()
# Remove the times from the results as they vary
res = {
if "time" not in metric
} == expected_res
# Remove results files
shutil.rmtree(evaluate_config["training"]["output_folder"] / "results")
# Check the metrics Markdown table
captured_std = capsys.readouterr()
last_printed_lines = captured_std.out.split("\n")[10:]
assert (
"\n".join(last_printed_lines)
== Path(FIXTURES / "evaluate" / "metrics_table.md").read_text()
)
@pytest.mark.parametrize(
"language_model_weight, expected_inferences",
(
(
0.0,
[
(
"0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84.png", # Image
"ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier Ⓟ12241", # Ground truth
"ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241", # Prediction
"", # LM prediction
0.125, # WER
),
(
"0dfe8bcd-ed0b-453e-bf19-cc697012296e.png", # Image
"ⓈTemplié ⒻMarcelle Ⓑ93 ⓁJ Ⓚch ⓄE dachyle", # Ground truth
"ⓈTemplié ⒻMarcelle Ⓑ93 ⓁS Ⓚch ⓄE dactylo Ⓟ18376", # Prediction
"", # LM prediction
0.2667, # WER
),
(
"2c242f5c-e979-43c4-b6f2-a6d4815b651d.png", # Image
"ⓈA ⒻCharles Ⓑ11 ⓁP ⒸC ⓀF ⓄA Ⓟ14331", # Ground truth
"Ⓢd ⒻCharles Ⓑ11 ⓁP ⒸC ⓀF Ⓞd Ⓟ14 31", # Prediction
"", # LM prediction
0.5, # WER
),
(
"ffdec445-7f14-4f5f-be44-68d0844d0df1.png", # Image
"ⓈNaudin ⒻMarie Ⓑ53 ⓁS ⒸV ⓀBelle mère", # Ground truth
"ⓈNaudin ⒻMarie Ⓑ53 ⓁS Ⓒv ⓀBelle mère", # Prediction
"", # LM prediction
0.1429, # WER
),
],
),
(
1.0,
[
(
"0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84.png", # Image
"ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier Ⓟ12241", # Ground truth
"ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241", # Prediction
"ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241", # LM prediction
0.125, # WER
),
(
"0dfe8bcd-ed0b-453e-bf19-cc697012296e.png", # Image
"ⓈTemplié ⒻMarcelle Ⓑ93 ⓁJ Ⓚch ⓄE dachyle", # Ground truth
"ⓈTemplié ⒻMarcelle Ⓑ93 ⓁS Ⓚch ⓄE dactylo Ⓟ18376", # Prediction
"ⓈTemplié ⒻMarcelle Ⓑ93 ⓁS Ⓚch ⓄE dactylo Ⓟ18376", # LM prediction
0.2667, # WER
),
(
"2c242f5c-e979-43c4-b6f2-a6d4815b651d.png", # Image
"ⓈA ⒻCharles Ⓑ11 ⓁP ⒸC ⓀF ⓄA Ⓟ14331", # Ground truth
"Ⓢd ⒻCharles Ⓑ11 ⓁP ⒸC ⓀF Ⓞd Ⓟ14 31", # Prediction
"Ⓢd ⒻCharles Ⓑ11 ⓁP ⒸC ⓀF Ⓞd Ⓟ14 31", # LM prediction
0.5, # WER
),
(
"ffdec445-7f14-4f5f-be44-68d0844d0df1.png", # Image
"ⓈNaudin ⒻMarie Ⓑ53 ⓁS ⒸV ⓀBelle mère", # Ground truth
"ⓈNaudin ⒻMarie Ⓑ53 ⓁS Ⓒv ⓀBelle mère", # Prediction
"ⓈNaudin ⒻMarie Ⓑ53 ⓁS Ⓒv ⓀBelle mère", # LM prediction
0.1429, # WER
),
],
),
(
2.0,
[
(
"0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84.png", # Image
"ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier Ⓟ12241", # Ground truth
"ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241", # Prediction
"ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241", # LM prediction
0.125, # WER
),
(
"0dfe8bcd-ed0b-453e-bf19-cc697012296e.png", # Image
"ⓈTemplié ⒻMarcelle Ⓑ93 ⓁJ Ⓚch ⓄE dachyle", # Ground truth
"ⓈTemplié ⒻMarcelle Ⓑ93 ⓁS Ⓚch ⓄE dactylo Ⓟ18376", # Prediction
"ⓈTemplié ⒻMarcelle Ⓑ93 ⓁS Ⓚch ⓄE dactylo Ⓟ18376", # LM prediction
0.2667, # WER
),
(
"2c242f5c-e979-43c4-b6f2-a6d4815b651d.png", # Image
"ⓈA ⒻCharles Ⓑ11 ⓁP ⒸC ⓀF ⓄA Ⓟ14331", # Ground truth
"Ⓢd ⒻCharles Ⓑ11 ⓁP ⒸC ⓀF Ⓞd Ⓟ14 31", # Prediction
"Ⓢd ⒻCharles Ⓑ11 ⓁP ⒸC ⓀF Ⓞd Ⓟ14331", # LM prediction
0.5, # WER
),
(
"ffdec445-7f14-4f5f-be44-68d0844d0df1.png", # Image
"ⓈNaudin ⒻMarie Ⓑ53 ⓁS ⒸV ⓀBelle mère", # Ground truth
"ⓈNaudin ⒻMarie Ⓑ53 ⓁS Ⓒv ⓀBelle mère", # Prediction
"ⓈNaudin ⒻMarie Ⓑ53 ⓁS Ⓒv ⓀBelle mère", # LM prediction
0.1429, # WER
),
],
),
),
)
def test_evaluate_language_model(
capsys, evaluate_config, language_model_weight, expected_inferences, monkeypatch
):
# LM predictions are never used/displayed
# We mock the `Inference` class to temporary check the results
global nb_inferences
nb_inferences = 0
class MockInference(Inference):
def __new__(cls, *args, **kwargs):
global nb_inferences
assert args == expected_inferences[nb_inferences]
nb_inferences += 1
return super().__new__(cls, *args, **kwargs)
monkeypatch.setattr("dan.ocr.manager.training.Inference", MockInference)
# Use the tmp_path as base folder
evaluate_config["training"]["output_folder"] = FIXTURES / "evaluate"
# Use a LM decoder
evaluate_config["model"]["lm"] = {
"path": PREDICTION_DATA_PATH / "language_model.arpa",
"weight": language_model_weight,
}
evaluate.run(evaluate_config, evaluate.NERVAL_THRESHOLD)
# Check that the evaluation results are correct
for split_name, expected_res in [
(
"train",
{
"nb_chars": 90,
"cer": 0.1889,
"nb_chars_no_token": 76,
"cer_no_token": 0.2105,
"nb_words": 15,
"wer": 0.2667,
"nb_words_no_punct": 15,
"wer_no_punct": 0.2667,
"nb_words_no_token": 15,
"wer_no_token": 0.2667,
"nb_tokens": 14,
"ner": 0.0714,
"nb_samples": 2,
},
),
(
"val",
{
"nb_chars": 34,
"cer": 0.0882,
"nb_chars_no_token": 26,
"cer_no_token": 0.1154,
"nb_words": 8,
"wer": 0.5,
"nb_words_no_punct": 8,
"wer_no_punct": 0.5,
"nb_words_no_token": 8,
"wer_no_token": 0.5,
"nb_tokens": 8,
"ner": 0.0,
"nb_samples": 1,
},
),
(
"test",
{
"nb_chars": 36,
"cer": 0.0278,
"nb_chars_no_token": 30,
"cer_no_token": 0.0333,
"nb_words": 7,
"wer": 0.1429,
"nb_words_no_punct": 7,
"wer_no_punct": 0.1429,
"nb_words_no_token": 7,
"wer_no_token": 0.1429,
"nb_tokens": 6,
"ner": 0.0,
"nb_samples": 1,
},
),
]:
filename = (
evaluate_config["training"]["output_folder"]
/ "results"
/ f"predict_training-{split_name}_1685.yaml"
)
with filename.open() as f:
assert {
metric: value
for metric, value in yaml.safe_load(f).items()
# Remove the times from the results as they vary
if "time" not in metric
}
assert res == expected_res
} == expected_res
# Remove results files
shutil.rmtree(evaluate_config["training"]["output_folder"] / "results")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment