Skip to content
Snippets Groups Projects
Commit c80c4139 authored by Solene Tarride's avatar Solene Tarride
Browse files

Write tests for LM decoding

parent 43ccb808
No related branches found
No related tags found
No related merge requests found
......@@ -93,7 +93,7 @@ class DAN:
self.decoder = decoder
self.lm_decoder = None
if use_language_model:
if use_language_model and parameters["lm_decoder"]["language_model_weight"] > 0:
self.lm_decoder = CTCLanguageDecoder(
language_model_path=parameters["lm_decoder"]["language_model_path"],
lexicon_path=parameters["lm_decoder"]["lexicon_path"],
......@@ -509,6 +509,9 @@ def run(
model, parameters, charset, mode="eval", use_language_model=use_language_model
)
# Do not use LM with invalid LM weight
use_language_model = dan_model.lm_decoder is not None
images = image_dir.rglob(f"*{image_extension}") if not image else [image]
for image_batch in list_to_batches(images, n=batch_size):
process_batch(
......
⎵ ⎵
! !
" "
& &
' '
( (
) )
- -
0 0
1 1
2 2
3 3
4 4
5 5
6 6
7 7
8 8
9 9
: :
? ?
A A
B B
C C
D D
E E
F F
G G
H H
I I
J J
K K
L L
M M
N N
O O
P P
Q Q
R R
S S
T T
U U
V V
W W
X X
Y Y
Z Z
a a
b b
c c
d d
e e
f f
g g
h h
i i
j j
k k
l l
m m
n n
o o
p p
q q
r r
s s
t t
u u
v v
w w
x x
y y
z z
° °
à à
â â
ç ç
è è
é é
ê ê
ë ë
î î
ï ï
ô ô
ö ö
ù ù
ü ü
Ⓑ Ⓑ
Ⓒ Ⓒ
Ⓔ Ⓔ
Ⓕ Ⓕ
Ⓚ Ⓚ
Ⓛ Ⓛ
Ⓝ Ⓝ
Ⓞ Ⓞ
Ⓟ Ⓟ
Ⓢ Ⓢ
◌ ◌
This diff is collapsed.
!
"
&
'
(
)
-
0
1
2
3
4
5
6
7
8
9
:
?
A
B
C
D
E
F
G
H
I
J
K
L
M
N
O
P
Q
R
S
T
U
V
W
X
Y
Z
a
b
c
d
e
f
g
h
i
j
k
l
m
n
o
p
q
r
s
t
u
v
w
x
y
z
°
à
â
ç
è
é
ê
ë
î
ï
ô
ö
ù
ü
......@@ -22,3 +22,8 @@ parameters:
- type: "max_resize"
max_height: 1500
max_width: 1500
lm_decoder:
language_model_path: tests/data/prediction/language_model.arpa
lexicon_path: tests/data/prediction/language_lexicon.txt
tokens_path: tests/data/prediction/language_tokens.txt
language_model_weight: 1.0
......@@ -4,11 +4,12 @@ import json
import shutil
import pytest
import yaml
from dan.ocr.predict.attention import Level
from dan.ocr.predict.prediction import DAN
from dan.ocr.predict.prediction import run as run_prediction
from dan.utils import parse_tokens
from dan.utils import parse_tokens, read_yaml
from tests import FIXTURES
PREDICTION_DATA_PATH = FIXTURES / "prediction"
......@@ -527,3 +528,147 @@ def test_run_prediction_batch(
(tmp_path / image_name).with_suffix(".json").read_text()
)
assert prediction == expected_prediction
@pytest.mark.parametrize(
    "image_names, language_model_weight, expected_predictions",
    (
        (
            [
                "0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84",
                "0dfe8bcd-ed0b-453e-bf19-cc697012296e",
                "2c242f5c-e979-43c4-b6f2-a6d4815b651d",
                "ffdec445-7f14-4f5f-be44-68d0844d0df1",
            ],
            1.0,
            [
                {
                    "text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",
                    "language_model": {
                        "text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241"
                    },
                },
                {
                    "text": "ⓈTemplié ⒻMarcelle Ⓑ93 ⓁS Ⓚch ⓄE dactylo Ⓟ18376",
                    "language_model": {
                        "text": "ⓈTemplié ⒻMarcelle Ⓑ93 ⓁS Ⓚch ⓄE dactylo Ⓟ18376"
                    },
                },
                {
                    "text": "Ⓢd ⒻCharles Ⓑ11 ⓁP ⒸC ⓀF Ⓞd Ⓟ14 31",
                    "language_model": {"text": "Ⓢd ⒻCharles Ⓑ11 ⓁP ⒸC ⓀF Ⓞd Ⓟ1431"},
                },
                {
                    "text": "ⓈNaudin ⒻMarie Ⓑ53 ⓁS Ⓒv ⓀBelle mère",
                    "language_model": {"text": "ⓈNaudin ⒻMarie Ⓑ53 ⓁS Ⓒv ⓀBelle mère"},
                },
            ],
        ),
        (
            [
                "0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84",
                "0dfe8bcd-ed0b-453e-bf19-cc697012296e",
                "2c242f5c-e979-43c4-b6f2-a6d4815b651d",
                "ffdec445-7f14-4f5f-be44-68d0844d0df1",
            ],
            2.0,
            [
                {
                    "text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",
                    "language_model": {
                        "text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241"
                    },
                },
                {
                    "text": "ⓈTemplié ⒻMarcelle Ⓑ93 ⓁS Ⓚch ⓄE dactylo Ⓟ18376",
                    "language_model": {
                        "text": "ⓈTemplié ⒻMarcelle Ⓑ93 ⓁS Ⓚch ⓄE dactylo Ⓟ18376"
                    },
                },
                {
                    "text": "Ⓢd ⒻCharles Ⓑ11 ⓁP ⒸC ⓀF Ⓞd Ⓟ14 31",
                    "language_model": {"text": "Ⓢd ⒻCharles Ⓑ11 ⓁP ⒸC ⓀF Ⓞd Ⓟ14331"},
                },
                {
                    "text": "ⓈNaudin ⒻMarie Ⓑ53 ⓁS Ⓒv ⓀBelle mère",
                    "language_model": {"text": "ⓈNaudin ⒻMarie Ⓑ53 ⓁS Ⓒv ⓀBelle mère"},
                },
            ],
        ),
        (
            [
                "0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84",
                "0dfe8bcd-ed0b-453e-bf19-cc697012296e",
                "2c242f5c-e979-43c4-b6f2-a6d4815b651d",
                "ffdec445-7f14-4f5f-be44-68d0844d0df1",
            ],
            0.0,
            [
                {"text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241"},
                {"text": "ⓈTemplié ⒻMarcelle Ⓑ93 ⓁS Ⓚch ⓄE dactylo Ⓟ18376"},
                {"text": "Ⓢd ⒻCharles Ⓑ11 ⓁP ⒸC ⓀF Ⓞd Ⓟ14 31"},
                {"text": "ⓈNaudin ⒻMarie Ⓑ53 ⓁS Ⓒv ⓀBelle mère"},
            ],
        ),
    ),
)
@pytest.mark.parametrize("batch_size", [1, 2])
def test_run_prediction_language_model(
    image_names,
    language_model_weight,
    expected_predictions,
    batch_size,
    tmp_path,
):
    """Run prediction with a CTC language-model decoder and check its output.

    Covers three LM weights: 1.0 and 2.0 (LM rescoring active, a
    ``language_model`` entry is expected in each JSON prediction) and 0.0
    (LM disabled by `run_prediction`, so only the raw ``text`` is checked).

    :param image_names: Stems of the fixture images to predict on.
    :param language_model_weight: Weight written into ``parameters.yml``
        before prediction; 0 disables the LM decoder.
    :param expected_predictions: Expected JSON content per image, in the
        same order as ``image_names``.
    :param batch_size: Number of images per inference batch.
    :param tmp_path: pytest-provided temporary directory.
    """
    # Make tmpdir and copy needed images inside
    image_dir = tmp_path / "images"
    image_dir.mkdir()
    for image_name in image_names:
        shutil.copyfile(
            (PREDICTION_DATA_PATH / "images" / image_name).with_suffix(".png"),
            (image_dir / image_name).with_suffix(".png"),
        )

    # Update language_model_weight in parameters.yml.
    # Use a context manager so the file is closed (and flushed) before
    # run_prediction reads it back — the previous open() leaked the handle.
    params = read_yaml(PREDICTION_DATA_PATH / "parameters.yml")
    params["parameters"]["lm_decoder"]["language_model_weight"] = language_model_weight
    with (tmp_path / "parameters.yml").open("w") as file:
        yaml.dump(params, file)

    run_prediction(
        image=None,
        image_dir=image_dir,
        model=PREDICTION_DATA_PATH / "popp_line_model.pt",
        parameters=tmp_path / "parameters.yml",
        charset=PREDICTION_DATA_PATH / "charset.pkl",
        output=tmp_path,
        confidence_score=False,
        confidence_score_levels=[],
        attention_map=[],
        attention_map_level=None,
        attention_map_scale=0.5,
        word_separators=[" ", "\n"],
        line_separators=["\n"],
        temperature=1.0,
        predict_objects=False,
        threshold_method="otsu",
        threshold_value=0,
        max_object_height=None,
        image_extension=".png",
        gpu_device=None,
        batch_size=batch_size,
        tokens=parse_tokens(PREDICTION_DATA_PATH / "tokens.yml"),
        start_token=None,
        use_language_model=True,
    )

    # One JSON file per input image, named after the image stem.
    for image_name, expected_prediction in zip(image_names, expected_predictions):
        prediction = json.loads(
            (tmp_path / image_name).with_suffix(".json").read_text()
        )
        assert prediction["text"] == expected_prediction["text"]
        # The LM rescored text is only produced when the LM is enabled.
        if language_model_weight > 0:
            assert (
                prediction["language_model"]["text"]
                == expected_prediction["language_model"]["text"]
            )
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment