Skip to content
Snippets Groups Projects
Commit 2c0b4629 authored by Solene Tarride's avatar Solene Tarride
Browse files

Write tests for LM decoding

parent fc6d283b
No related branches found
No related tags found
No related merge requests found
......@@ -92,7 +92,7 @@ class DAN:
self.decoder = decoder
self.lm_decoder = None
if use_language_model:
if use_language_model and parameters["lm_decoder"]["language_model_weight"] > 0:
self.lm_decoder = CTCLanguageDecoder(
language_model_path=parameters["lm_decoder"]["language_model_path"],
lexicon_path=parameters["lm_decoder"]["lexicon_path"],
......@@ -479,6 +479,9 @@ def run(
)
batch_size = 1 if use_language_model else batch_size
# Do not use LM with invalid LM weight
use_language_model = dan_model.lm_decoder is not None
images = image_dir.rglob(f"*{image_extension}") if not image else [image]
for image_batch in list_to_batches(images, n=batch_size):
process_batch(
......
......@@ -22,8 +22,8 @@ parameters:
- type: "max_resize"
max_height: 1500
max_width: 1500
language_model:
model: tests/data/prediction/language_model.arpa
lexicon: tests/data/prediction/language_lexicon.txt
tokens: tests/data/prediction/language_tokens.txt
weight: 1.0
lm_decoder:
language_model_path: tests/data/prediction/language_model.arpa
lexicon_path: tests/data/prediction/language_lexicon.txt
tokens_path: tests/data/prediction/language_tokens.txt
language_model_weight: 1.0
......@@ -506,19 +506,28 @@ def test_run_prediction_batch(
{
"text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",
"language_model": {
<<<<<<< HEAD
"text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",
"confidence": 0.92,
=======
"text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241"
>>>>>>> c80c413 (Write tests for LM decoding)
},
},
{
"text": "ⓈTemplié ⒻMarcelle Ⓑ93 ⓁS Ⓚch ⓄE dactylo Ⓟ18376",
"language_model": {
<<<<<<< HEAD
"text": "ⓈTemplié ⒻMarcelle Ⓑ93 ⓁS Ⓚch ⓄE dactylo Ⓟ18376",
"confidence": 0.88,
=======
"text": "ⓈTemplié ⒻMarcelle Ⓑ93 ⓁS Ⓚch ⓄE dactylo Ⓟ18376"
>>>>>>> c80c413 (Write tests for LM decoding)
},
},
{
"text": "Ⓢd ⒻCharles Ⓑ11 ⓁP ⒸC ⓀF Ⓞd Ⓟ14 31",
<<<<<<< HEAD
"language_model": {
"text": "Ⓢd ⒻCharles Ⓑ11 ⓁP ⒸC ⓀF Ⓞd Ⓟ14 31",
"confidence": 0.86,
......@@ -530,6 +539,13 @@ def test_run_prediction_batch(
"text": "ⓈNaudin ⒻMarie Ⓑ53 ⓁS Ⓒv ⓀBelle mère",
"confidence": 0.89,
},
=======
"language_model": {"text": "Ⓢd ⒻCharles Ⓑ11 ⓁP ⒸC ⓀF Ⓞd Ⓟ1431"},
},
{
"text": "ⓈNaudin ⒻMarie Ⓑ53 ⓁS Ⓒv ⓀBelle mère",
"language_model": {"text": "ⓈNaudin ⒻMarie Ⓑ53 ⓁS Ⓒv ⓀBelle mère"},
>>>>>>> c80c413 (Write tests for LM decoding)
},
],
),
......@@ -545,19 +561,28 @@ def test_run_prediction_batch(
{
"text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",
"language_model": {
<<<<<<< HEAD
"text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",
"confidence": 0.90,
=======
"text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241"
>>>>>>> c80c413 (Write tests for LM decoding)
},
},
{
"text": "ⓈTemplié ⒻMarcelle Ⓑ93 ⓁS Ⓚch ⓄE dactylo Ⓟ18376",
"language_model": {
<<<<<<< HEAD
"text": "ⓈTemplié ⒻMarcelle Ⓑ93 ⓁS Ⓚch ⓄE dactylo Ⓟ18376",
"confidence": 0.84,
=======
"text": "ⓈTemplié ⒻMarcelle Ⓑ93 ⓁS Ⓚch ⓄE dactylo Ⓟ18376"
>>>>>>> c80c413 (Write tests for LM decoding)
},
},
{
"text": "Ⓢd ⒻCharles Ⓑ11 ⓁP ⒸC ⓀF Ⓞd Ⓟ14 31",
<<<<<<< HEAD
"language_model": {
"text": "Ⓢd ⒻCharles Ⓑ11 ⓁP ⒸC ⓀF Ⓞd Ⓟ14331",
"confidence": 0.83,
......@@ -569,6 +594,13 @@ def test_run_prediction_batch(
"text": "ⓈNaudin ⒻMarie Ⓑ53 ⓁS Ⓒv ⓀBelle mère",
"confidence": 0.86,
},
=======
"language_model": {"text": "Ⓢd ⒻCharles Ⓑ11 ⓁP ⒸC ⓀF Ⓞd Ⓟ14331"},
},
{
"text": "ⓈNaudin ⒻMarie Ⓑ53 ⓁS Ⓒv ⓀBelle mère",
"language_model": {"text": "ⓈNaudin ⒻMarie Ⓑ53 ⓁS Ⓒv ⓀBelle mère"},
>>>>>>> c80c413 (Write tests for LM decoding)
},
],
),
......@@ -589,10 +621,18 @@ def test_run_prediction_batch(
),
),
)
<<<<<<< HEAD
=======
@pytest.mark.parametrize("batch_size", [1, 2])
>>>>>>> c80c413 (Write tests for LM decoding)
def test_run_prediction_language_model(
image_names,
language_model_weight,
expected_predictions,
<<<<<<< HEAD
=======
batch_size,
>>>>>>> c80c413 (Write tests for LM decoding)
tmp_path,
):
# Make tmpdir and copy needed images inside
......@@ -606,7 +646,11 @@ def test_run_prediction_language_model(
# Update language_model_weight in parameters.yml
params = read_yaml(PREDICTION_DATA_PATH / "parameters.yml")
<<<<<<< HEAD
params["parameters"]["language_model"]["weight"] = language_model_weight
=======
params["parameters"]["lm_decoder"]["language_model_weight"] = language_model_weight
>>>>>>> c80c413 (Write tests for LM decoding)
yaml.dump(params, (tmp_path / "parameters.yml").open("w"))
run_prediction(
......@@ -630,7 +674,11 @@ def test_run_prediction_language_model(
max_object_height=None,
image_extension=".png",
gpu_device=None,
<<<<<<< HEAD
batch_size=1,
=======
batch_size=batch_size,
>>>>>>> c80c413 (Write tests for LM decoding)
tokens=parse_tokens(PREDICTION_DATA_PATH / "tokens.yml"),
start_token=None,
use_language_model=True,
......@@ -647,7 +695,10 @@ def test_run_prediction_language_model(
prediction["language_model"]["text"]
== expected_prediction["language_model"]["text"]
)
<<<<<<< HEAD
assert np.isclose(
prediction["language_model"]["confidence"],
expected_prediction["language_model"]["confidence"],
)
=======
>>>>>>> c80c413 (Write tests for LM decoding)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment