Compare revisions: atr/dan
# -*- coding: utf-8 -*-
import shutil
import pytest
import yaml
from dan.ocr import evaluate
from tests import FIXTURES
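
# Expected metrics for the train, validation and test splits: character and
# word counts, CER, WER (also without punctuation) and the number of samples.
# Timing fields are excluded from the comparison in the test below.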
@pytest.mark.parametrize(
    "training_res, val_res, test_res",
    (
        (
            {
                "nb_chars": 43,
                "cer": 1.3023,
                "nb_words": 9,
                "wer": 1.0,
                "nb_words_no_punct": 9,
                "wer_no_punct": 1.0,
                "nb_samples": 2,
            },
            {
                "nb_chars": 41,
                "cer": 1.2683,
                "nb_words": 9,
                "wer": 1.0,
                "nb_words_no_punct": 9,
                "wer_no_punct": 1.0,
                "nb_samples": 2,
            },
            {
                "nb_chars": 49,
                "cer": 1.1224,
                "nb_words": 9,
                "wer": 1.0,
                "nb_words_no_punct": 9,
                "wer_no_punct": 1.0,
                "nb_samples": 2,
            },
        ),
    ),
)
def test_evaluate(training_res, val_res, test_res, evaluate_config):
    # Use the evaluation fixtures folder as the output folder
    evaluate_config["training"]["output_folder"] = FIXTURES / "evaluate"
    evaluate.run(evaluate_config)

    # Check that the evaluation results are correct
    for split_name, expected_res in zip(
        ["train", "val", "test"], [training_res, val_res, test_res]
    ):
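        # Each split's metrics are written to a YAML file under the results folder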
        filename = (
            evaluate_config["training"]["output_folder"]
            / "results"
            / f"predict_training-{split_name}_0.yaml"
        )

        with filename.open() as f:
            # Remove the times from the results as they vary
            res = {
                metric: value
                for metric, value in yaml.safe_load(f).items()
                if "time" not in metric
            }

            assert res == expected_res

    # Remove results files
    shutil.rmtree(evaluate_config["training"]["output_folder"] / "results")
@@ -37,16 +37,31 @@ PREDICTION_DATA_PATH = FIXTURES / "prediction"
         ),
     ),
 )
-def test_predict(image_name, expected_prediction):
+@pytest.mark.parametrize("normalize", (True, False))
+def test_predict(image_name, expected_prediction, normalize, tmp_path):
+    # Update mean/std in parameters.yml
+    model_path = tmp_path / "models"
+    model_path.mkdir(exist_ok=True)
+    shutil.copyfile(
+        PREDICTION_DATA_PATH / "model.pt",
+        model_path / "model.pt",
+    )
+    shutil.copyfile(
+        PREDICTION_DATA_PATH / "charset.pkl",
+        model_path / "charset.pkl",
+    )
+    params = read_yaml(PREDICTION_DATA_PATH / "parameters.yml")
+    if not normalize:
+        del params["parameters"]["mean"]
+        del params["parameters"]["std"]
+    yaml.dump(params, (model_path / "parameters.yml").open("w"))
     device = "cpu"
     dan_model = DAN(device)
-    dan_model.load(
-        model_path=PREDICTION_DATA_PATH / "popp_line_model.pt",
-        params_path=PREDICTION_DATA_PATH / "parameters.yml",
-        charset_path=PREDICTION_DATA_PATH / "charset.pkl",
-        mode="eval",
-    )
+    dan_model.load(path=model_path, mode="eval")
     image_path = PREDICTION_DATA_PATH / "images" / image_name
     _, image = dan_model.preprocess(str(image_path))
@@ -298,12 +313,17 @@ def test_run_prediction(
     expected_prediction,
     tmp_path,
 ):
+    # Make tmpdir and copy needed image inside
+    image_dir = tmp_path / "images"
+    image_dir.mkdir()
+    shutil.copyfile(
+        (PREDICTION_DATA_PATH / "images" / image_name).with_suffix(".png"),
+        (image_dir / image_name).with_suffix(".png"),
+    )
     run_prediction(
-        image=(PREDICTION_DATA_PATH / "images" / image_name).with_suffix(".png"),
-        image_dir=None,
-        model=PREDICTION_DATA_PATH / "popp_line_model.pt",
-        parameters=PREDICTION_DATA_PATH / "parameters.yml",
-        charset=PREDICTION_DATA_PATH / "charset.pkl",
+        image_dir=image_dir,
+        model=PREDICTION_DATA_PATH,
         output=tmp_path,
         confidence_score=True if confidence_score else False,
         confidence_score_levels=confidence_score if confidence_score else [],
@@ -314,10 +334,8 @@ def test_run_prediction(
         line_separators=["\n"],
         temperature=temperature,
         predict_objects=False,
-        threshold_method="otsu",
-        threshold_value=0,
         max_object_height=None,
-        image_extension=None,
+        image_extension=".png",
         gpu_device=None,
         batch_size=1,
         tokens=parse_tokens(PREDICTION_DATA_PATH / "tokens.yml"),
@@ -497,11 +515,8 @@ def test_run_prediction_batch(
     )
     run_prediction(
         image=None,
         image_dir=image_dir,
-        model=PREDICTION_DATA_PATH / "popp_line_model.pt",
-        parameters=PREDICTION_DATA_PATH / "parameters.yml",
-        charset=PREDICTION_DATA_PATH / "charset.pkl",
+        model=PREDICTION_DATA_PATH,
         output=tmp_path,
         confidence_score=True if confidence_score else False,
         confidence_score_levels=confidence_score if confidence_score else [],
@@ -512,8 +527,6 @@
         line_separators=["\n"],
         temperature=temperature,
         predict_objects=False,
-        threshold_method="otsu",
-        threshold_value=0,
         max_object_height=None,
         image_extension=".png",
         gpu_device=None,
@@ -644,16 +657,25 @@ def test_run_prediction_language_model(
     )
     # Update language_model_weight in parameters.yml
+    model_path = tmp_path / "models"
+    model_path.mkdir(exist_ok=True)
+    shutil.copyfile(
+        PREDICTION_DATA_PATH / "model.pt",
+        model_path / "model.pt",
+    )
+    shutil.copyfile(
+        PREDICTION_DATA_PATH / "charset.pkl",
+        model_path / "charset.pkl",
+    )
     params = read_yaml(PREDICTION_DATA_PATH / "parameters.yml")
     params["parameters"]["language_model"]["weight"] = language_model_weight
-    yaml.dump(params, (tmp_path / "parameters.yml").open("w"))
+    yaml.dump(params, (model_path / "parameters.yml").open("w"))
     run_prediction(
         image=None,
         image_dir=image_dir,
-        model=PREDICTION_DATA_PATH / "popp_line_model.pt",
-        parameters=tmp_path / "parameters.yml",
-        charset=PREDICTION_DATA_PATH / "charset.pkl",
+        model=model_path,
         output=tmp_path,
         confidence_score=False,
         confidence_score_levels=[],
@@ -664,8 +686,6 @@
         line_separators=["\n"],
         temperature=1.0,
         predict_objects=False,
-        threshold_method="otsu",
-        threshold_value=0,
         max_object_height=None,
         image_extension=".png",
         gpu_device=None,
......
@@ -6,43 +6,17 @@ import pytest
 import torch
 import yaml

-from dan.ocr.train import train_and_test
+from dan.ocr.train import train
+from dan.ocr.utils import update_config
 from tests.conftest import FIXTURES


 @pytest.mark.parametrize(
-    "expected_best_model_name, expected_last_model_name, training_res, val_res, test_res, params_res",
+    "expected_best_model_name, expected_last_model_name, params_res",
     (
         (
             "best_0.pt",
             "last_3.pt",
-            {
-                "nb_chars": 43,
-                "cer": 1.3023,
-                "nb_words": 9,
-                "wer": 1.0,
-                "nb_words_no_punct": 9,
-                "wer_no_punct": 1.0,
-                "nb_samples": 2,
-            },
-            {
-                "nb_chars": 41,
-                "cer": 1.2683,
-                "nb_words": 9,
-                "wer": 1.0,
-                "nb_words_no_punct": 9,
-                "wer_no_punct": 1.0,
-                "nb_samples": 2,
-            },
-            {
-                "nb_chars": 49,
-                "cer": 1.1224,
-                "nb_words": 9,
-                "wer": 1.0,
-                "nb_words_no_punct": 9,
-                "wer_no_punct": 1.0,
-                "nb_samples": 2,
-            },
             {
                 "parameters": {
                     "max_char_prediction": 30,
@@ -79,22 +53,21 @@ from tests.conftest import FIXTURES
         ),
     ),
 )
-def test_train_and_test(
+def test_train(
     expected_best_model_name,
     expected_last_model_name,
-    training_res,
-    val_res,
-    test_res,
     params_res,
     training_config,
     tmp_path,
 ):
+    update_config(training_config)
+
     # Use the tmp_path as base folder
     training_config["training"]["output_folder"] = (
         tmp_path / training_config["training"]["output_folder"]
     )
-    train_and_test(0, training_config)
+    train(0, training_config)

     # There should only be two checkpoints left
     checkpoints = (
@@ -175,24 +148,6 @@ def test_train_and_test(
     ]:
         assert trained_model[elt] == expected_model[elt]

-    # Check that the evaluation results are correct
-    for split_name, expected_res in zip(
-        ["train", "val", "test"], [training_res, val_res, test_res]
-    ):
-        with (
-            tmp_path
-            / training_config["training"]["output_folder"]
-            / "results"
-            / f"predict_training-{split_name}_0.yaml"
-        ).open() as f:
-            # Remove the times from the results as they vary
-            res = {
-                metric: value
-                for metric, value in yaml.safe_load(f).items()
-                if "time" not in metric
-            }
-            assert res == expected_res
-
     # Check that the inference parameters file is correct
     res = yaml.safe_load(
         (
......