# -*- coding: utf-8 -*-
import shutil

import pytest
import yaml

from dan.ocr import evaluate
from tests import FIXTURES


# Expected metrics (character/word error rates and sample counts) for each split
@pytest.mark.parametrize(
    "training_res, val_res, test_res",
    (
        (
            {
                "nb_chars": 43,
                "cer": 1.3023,
                "nb_words": 9,
                "wer": 1.0,
                "nb_words_no_punct": 9,
                "wer_no_punct": 1.0,
                "nb_samples": 2,
            },
            {
                "nb_chars": 41,
                "cer": 1.2683,
                "nb_words": 9,
                "wer": 1.0,
                "nb_words_no_punct": 9,
                "wer_no_punct": 1.0,
                "nb_samples": 2,
            },
            {
                "nb_chars": 49,
                "cer": 1.1224,
                "nb_words": 9,
                "wer": 1.0,
                "nb_words_no_punct": 9,
                "wer_no_punct": 1.0,
                "nb_samples": 2,
            },
        ),
    ),
)
def test_evaluate(training_res, val_res, test_res, evaluate_config):
    # Use the evaluate fixtures folder as the output folder
    evaluate_config["training"]["output_folder"] = FIXTURES / "evaluate"

    evaluate.run(evaluate_config)

    # Check that the evaluation results are correct for each split
    for split_name, expected_res in zip(
        ["train", "val", "test"], [training_res, val_res, test_res]
    ):
        filename = (
            evaluate_config["training"]["output_folder"]
            / "results"
            / f"predict_training-{split_name}_0.yaml"
        )

        with filename.open() as f:
            # Remove the times from the results as they vary between runs
            res = {
                metric: value
                for metric, value in yaml.safe_load(f).items()
                if "time" not in metric
            }
            assert res == expected_res

    # Remove results files
    shutil.rmtree(evaluate_config["training"]["output_folder"] / "results")