# -*- coding: utf-8 -*-
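"""Tests for the evaluation entrypoint in dan.ocr.evaluate."""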

import shutil

import pytest
import yaml

from dan.ocr import evaluate
from tests import FIXTURES


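# Expected metrics (character/word error rates and sample counts) for the
# train, val and test splits of the evaluation fixtures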
@pytest.mark.parametrize(
    "training_res, val_res, test_res",
    (
        (
            {
                "nb_chars": 43,
                "cer": 1.3023,
                "nb_words": 9,
                "wer": 1.0,
                "nb_words_no_punct": 9,
                "wer_no_punct": 1.0,
                "nb_samples": 2,
            },
            {
                "nb_chars": 41,
                "cer": 1.2683,
                "nb_words": 9,
                "wer": 1.0,
                "nb_words_no_punct": 9,
                "wer_no_punct": 1.0,
                "nb_samples": 2,
            },
            {
                "nb_chars": 49,
                "cer": 1.1224,
                "nb_words": 9,
                "wer": 1.0,
                "nb_words_no_punct": 9,
                "wer_no_punct": 1.0,
                "nb_samples": 2,
            },
        ),
    ),
)
def test_evaluate(training_res, val_res, test_res, evaluate_config):
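    """Run the evaluation on the fixture dataset and compare the per-split
    metrics against the expected values."""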
    # Use the evaluation fixtures folder as the output folder
    evaluate_config["training"]["output_folder"] = FIXTURES / "evaluate"

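    # Run the evaluation; results are written under <output_folder>/results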
    evaluate.run(evaluate_config)

    # Check that the evaluation results are correct
    for split_name, expected_res in zip(
        ["train", "val", "test"], [training_res, val_res, test_res]
    ):
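        # One YAML result file is produced per split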
        filename = (
            evaluate_config["training"]["output_folder"]
            / "results"
            / f"predict_training-{split_name}_0.yaml"
        )

        with filename.open() as f:
            # Remove the times from the results as they vary
            res = {
                metric: value
                for metric, value in yaml.safe_load(f).items()
                if "time" not in metric
            }
            assert res == expected_res

    # Remove results files
    shutil.rmtree(evaluate_config["training"]["output_folder"] / "results")