test_evaluate.py
# -*- coding: utf-8 -*-
"""Tests for ``dan.ocr.evaluate``: check the metrics reported for the train,
validation and test splits of the evaluation fixture."""
import shutil

import pytest
import yaml

from dan.ocr import evaluate
from tests import FIXTURES


@pytest.mark.parametrize(
"training_res, val_res, test_res",
(
(
{
"nb_chars": 43,
"cer": 1.3023,
"nb_chars_no_token": 43,
"cer_no_token": 1.3023,
"nb_words": 9,
"wer": 1.0,
"nb_words_no_punct": 9,
"wer_no_punct": 1.0,
"nb_words_no_token": 9,
"wer_no_token": 1.0,
"nb_samples": 2,
},
{
"nb_chars": 41,
"cer": 1.2683,
"nb_chars_no_token": 41,
"cer_no_token": 1.2683,
"nb_words": 9,
"wer": 1.0,
"nb_words_no_punct": 9,
"wer_no_punct": 1.0,
"nb_words_no_token": 9,
"wer_no_token": 1.0,
"nb_samples": 2,
},
{
"nb_chars": 49,
"cer": 1.1224,
"nb_chars_no_token": 49,
"cer_no_token": 1.1224,
"nb_words": 9,
"wer": 1.0,
"nb_words_no_punct": 9,
"wer_no_punct": 1.0,
"nb_words_no_token": 9,
"wer_no_token": 1.0,
"nb_samples": 2,
},
),
),
)
def test_evaluate(training_res, val_res, test_res, evaluate_config):
    # Use the evaluate fixtures directory as the base output folder
    evaluate_config["training"]["output_folder"] = FIXTURES / "evaluate"

    evaluate.run(evaluate_config)

    # Check that the evaluation results are correct for each split
    for split_name, expected_res in zip(
        ["train", "val", "test"], [training_res, val_res, test_res]
    ):
        filename = (
            evaluate_config["training"]["output_folder"]
            / "results"
            / f"predict_training-{split_name}_0.yaml"
        )
        with filename.open() as f:
            # Remove the times from the results as they vary between runs
            res = {
                metric: value
                for metric, value in yaml.safe_load(f).items()
                if "time" not in metric
            }
        assert res == expected_res

    # Clean up the generated results files
    shutil.rmtree(evaluate_config["training"]["output_folder"] / "results")
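
# Note: the `evaluate_config` fixture used above is provided outside this
# file (typically in tests/conftest.py). The sketch below is a hypothetical,
# minimal illustration showing only the key this test actually touches
# ("training" -> "output_folder"); the real fixture builds the full DAN
# evaluation configuration and should not be replaced by this snippet.
#
#   import pytest
#   from tests import FIXTURES
#
#   @pytest.fixture
#   def evaluate_config():
#       return {
#           "training": {
#               "output_folder": FIXTURES / "evaluate",
#           },
#       }
#
# With such a fixture in place, this module can be run on its own with
# `pytest tests/test_evaluate.py`.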