# -*- coding: utf-8 -*-
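"""Integration test for ``dan.ocr.document.train.train_and_test``.

Runs a short training, then checks that the saved checkpoints (model
weights, optimizer state and metadata) and the per-split evaluation
metrics match the pre-computed fixtures.
"""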

import pytest
import torch
import yaml

from dan.ocr.document.train import train_and_test
from tests.conftest import FIXTURES


@pytest.mark.parametrize(
    "expected_best_model_name, expected_last_model_name, training_res, val_res, test_res",
    (
        (
            "best_0.pt",
            "last_3.pt",
            {
                "nb_chars": 43,
                "cer": 1.2791,
                "nb_words": 9,
                "wer": 1.0,
                "nb_words_no_punct": 9,
                "wer_no_punct": 1.0,
                "nb_samples": 2,
            },
            {
                "nb_chars": 41,
                "cer": 1.2683,
                "nb_words": 9,
                "wer": 1.0,
                "nb_words_no_punct": 9,
                "wer_no_punct": 1.0,
                "nb_samples": 2,
            },
            {
                "nb_chars": 49,
                "cer": 1.1429,
                "nb_words": 9,
                "wer": 1.0,
                "nb_words_no_punct": 9,
                "wer_no_punct": 1.0,
                "nb_samples": 2,
            },
        ),
    ),
)
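# A single parametrization: the expected best/last checkpoint names and the
# expected metrics (error rates and counts) for the train, val and test splits.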
def test_train_and_test(
    expected_best_model_name,
    expected_last_model_name,
    training_res,
    val_res,
    test_res,
    training_config,
    tmp_path,
):
    # Use the pytest tmp_path as base folder so all training artifacts are
    # written to a temporary directory
    output_folder = tmp_path / training_config["training_params"]["output_folder"]
    training_config["training_params"]["output_folder"] = str(output_folder)

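    # Run training then evaluation (the first argument is presumably the
    # distributed process rank; 0 for a single-process run)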
    train_and_test(0, training_config)

    # Check that the best and last saved checkpoints match the fixtures
    for model_name in [expected_best_model_name, expected_last_model_name]:
        expected_model = torch.load(FIXTURES / "training" / "models" / model_name)
        trained_model = torch.load(output_folder / "checkpoints" / model_name)

        # Check the encoder optimizer's named parameters
        for trained, expected in zip(
            trained_model["optimizers_named_params"]["encoder"],
            expected_model["optimizers_named_params"]["encoder"],
        ):
            for (trained_param, trained_tensor), (
                expected_param,
                expected_tensor,
            ) in zip(trained.items(), expected.items()):
                assert trained_param == expected_param
                assert torch.allclose(trained_tensor, expected_tensor, atol=1e-03)

        # Check the encoder and decoder optimizer state dicts
        for optimizer_part in [
            "optimizer_encoder_state_dict",
            "optimizer_decoder_state_dict",
        ]:
            for trained, expected in zip(
                trained_model[optimizer_part]["state"].values(),
                expected_model[optimizer_part]["state"].values(),
            ):
                for (trained_param, trained_tensor), (
                    expected_param,
                    expected_tensor,
                ) in zip(trained.items(), expected.items()):
                    assert trained_param == expected_param
                    assert torch.allclose(
                        trained_tensor,
                        expected_tensor,
                        atol=1e-04,
                    )
            assert (
                trained_model[optimizer_part]["param_groups"]
                == expected_model[optimizer_part]["param_groups"]
            )

        # Check the encoder and decoder weights
        for model_part in ["encoder_state_dict", "decoder_state_dict"]:
            for (trained_name, trained_layer), (expected_name, expected_layer) in zip(
                trained_model[model_part].items(), expected_model[model_part].items()
            ):
                assert trained_name == expected_name
                assert torch.allclose(
                    trained_layer,
                    expected_layer,
                    atol=1e-03,
                )

        # Check the remaining checkpoint metadata
        for elt in [
            "epoch",
            "step",
            "scaler_state_dict",
            "best",
            "charset",
            "curriculum_config",
        ]:
            assert trained_model[elt] == expected_model[elt]

    # Check that the evaluation results are correct
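    # Each results file is a flat YAML mapping from metric name to value
    # (plus timing entries, filtered out below), e.g. for the test split:
    #   nb_chars: 49
    #   cer: 1.1429
    #   nb_words: 9
    #   wer: 1.0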
    for split_name, expected_res in zip(
        ["train", "val", "test"], [training_res, val_res, test_res]
    ):
        with (
            output_folder / "results" / f"predict_training-{split_name}_0.yaml"
        ).open() as f:
            # Drop the timing metrics, as they vary between runs
            res = {
                metric: value
                for metric, value in yaml.safe_load(f).items()
                if "time" not in metric
            }
            assert res == expected_res