Skip to content
Snippets Groups Projects
test_training.py 6.34 KiB
# -*- coding: utf-8 -*-

import pytest
import torch
import yaml

from dan.ocr.train import train_and_test
from tests.conftest import FIXTURES


def _assert_tensor_dicts_close(trained, expected, **tolerances):
    """Assert two ``name -> tensor`` mappings match key-for-key and are numerically close.

    The key lists (order included) are compared first because ``zip`` stops at
    the shorter iterable: without this check a missing or extra entry in either
    mapping would silently pass.

    :param trained: Mapping produced by the training run.
    :param expected: Reference mapping loaded from the fixtures.
    :param tolerances: Forwarded to ``torch.allclose`` (e.g. ``rtol``, ``atol``).
    """
    assert list(trained.keys()) == list(expected.keys())
    for name, trained_tensor in trained.items():
        assert torch.allclose(trained_tensor, expected[name], **tolerances), name


@pytest.mark.parametrize(
    "expected_best_model_name, expected_last_model_name, training_res, val_res, test_res, params_res",
    (
        (
            "best_0.pt",
            "last_3.pt",
            {
                "nb_chars": 43,
                "cer": 1.3023,
                "nb_words": 9,
                "wer": 1.0,
                "nb_words_no_punct": 9,
                "wer_no_punct": 1.0,
                "nb_samples": 2,
            },
            {
                "nb_chars": 41,
                "cer": 1.2683,
                "nb_words": 9,
                "wer": 1.0,
                "nb_words_no_punct": 9,
                "wer_no_punct": 1.0,
                "nb_samples": 2,
            },
            {
                "nb_chars": 49,
                "cer": 1.1224,
                "nb_words": 9,
                "wer": 1.0,
                "nb_words_no_punct": 9,
                "wer_no_punct": 1.0,
                "nb_samples": 2,
            },
            {
                "parameters": {
                    "max_char_prediction": 30,
                    "encoder": {"dropout": 0.5},
                    "decoder": {
                        "enc_dim": 256,
                        "l_max": 15000,
                        "h_max": 500,
                        "w_max": 1000,
                        "dec_num_layers": 8,
                        "dec_num_heads": 4,
                        "dec_res_dropout": 0.1,
                        "dec_pred_dropout": 0.1,
                        "dec_att_dropout": 0.1,
                        "dec_dim_feedforward": 256,
                        "vocab_size": 96,
                        "attention_win": 100,
                    },
                    "preprocessings": [
                        {
                            "max_height": 2000,
                            "max_width": 2000,
                            "type": "max_resize",
                        }
                    ],
                    "mean": [
                        242.10595854671013,
                        242.10595854671013,
                        242.10595854671013,
                    ],
                    "std": [28.29919517652322, 28.29919517652322, 28.29919517652322],
                },
            },
        ),
    ),
)
def test_train_and_test(
    expected_best_model_name,
    expected_last_model_name,
    training_res,
    val_res,
    test_res,
    params_res,
    training_config,
    tmp_path,
):
    """Run ``train_and_test`` and compare every produced artefact to the fixtures.

    Checks, for both the best and the last checkpoint: the encoder optimizer
    named parameters, the encoder/decoder optimizer states and param groups,
    the encoder/decoder weights and a few bookkeeping entries. Then checks the
    per-split evaluation metrics (timings excluded) and the exported
    inference parameters file.
    """
    # Redirect all outputs under pytest's temporary directory and keep the
    # resulting path, so later reads do not need to re-join it onto tmp_path.
    output_folder = tmp_path / training_config["training"]["output_folder"]
    training_config["training"]["output_folder"] = str(output_folder)

    train_and_test(0, training_config)

    # Check that the trained model is correct
    for model_name in [expected_best_model_name, expected_last_model_name]:
        expected_model = torch.load(FIXTURES / "training" / "models" / model_name)
        trained_model = torch.load(output_folder / "checkpoints" / model_name)

        # Check the optimizers parameters (same number of groups, then each
        # group key-for-key so a missing group or tensor cannot slip through)
        trained_groups = trained_model["optimizers_named_params"]["encoder"]
        expected_groups = expected_model["optimizers_named_params"]["encoder"]
        assert len(trained_groups) == len(expected_groups)
        for trained, expected in zip(trained_groups, expected_groups):
            _assert_tensor_dicts_close(trained, expected, rtol=1e-05, atol=1e-03)

        # Check the optimizer encoder and decoder state dicts
        for optimizer_part in [
            "optimizer_encoder_state_dict",
            "optimizer_decoder_state_dict",
        ]:
            trained_state = trained_model[optimizer_part]["state"]
            expected_state = expected_model[optimizer_part]["state"]
            # Pair per-parameter states by their key instead of by position
            assert trained_state.keys() == expected_state.keys()
            for param_id, trained in trained_state.items():
                _assert_tensor_dicts_close(
                    trained, expected_state[param_id], atol=1e-03
                )
            assert (
                trained_model[optimizer_part]["param_groups"]
                == expected_model[optimizer_part]["param_groups"]
            )

        # Check the encoder and decoder weights
        for model_part in ["encoder_state_dict", "decoder_state_dict"]:
            _assert_tensor_dicts_close(
                trained_model[model_part], expected_model[model_part], atol=1e-03
            )

        # Check the other information
        for elt in [
            "epoch",
            "step",
            "scaler_state_dict",
            "best",
            "charset",
            "curriculum_config",
        ]:
            assert trained_model[elt] == expected_model[elt]

    results_folder = output_folder / "results"

    # Check that the evaluation results are correct
    for split_name, expected_res in zip(
        ["train", "val", "test"], [training_res, val_res, test_res]
    ):
        split_results = results_folder / f"predict_training-{split_name}_0.yaml"
        with split_results.open(encoding="utf-8") as f:
            # Remove the times from the results as they vary
            res = {
                metric: value
                for metric, value in yaml.safe_load(f).items()
                if "time" not in metric
            }
            assert res == expected_res

    # Check that the inference parameters file is correct
    with (results_folder / "inference_parameters.yml").open(encoding="utf-8") as f:
        res = yaml.safe_load(f)
        assert res == params_res