# -*- coding: utf-8 -*-
import pytest
import torch
import yaml

from dan.ocr.document.train import train_and_test
from tests.conftest import FIXTURES


@pytest.mark.parametrize(
    "expected_best_model_name, expected_last_model_name, training_res, val_res, test_res",
    (
        (
            "best_0.pt",
            "last_3.pt",
            {
                "nb_chars": 43,
                "cer": 1.2791,
                "nb_words": 9,
                "wer": 1.0,
                "nb_words_no_punct": 9,
                "wer_no_punct": 1.0,
                "nb_samples": 2,
            },
            {
                "nb_chars": 41,
                "cer": 1.2683,
                "nb_words": 9,
                "wer": 1.0,
                "nb_words_no_punct": 9,
                "wer_no_punct": 1.0,
                "nb_samples": 2,
            },
            {
                "nb_chars": 49,
                "cer": 1.1429,
                "nb_words": 9,
                "wer": 1.0,
                "nb_words_no_punct": 9,
                "wer_no_punct": 1.0,
                "nb_samples": 2,
            },
        ),
    ),
)
def test_train_and_test(
    expected_best_model_name,
    expected_last_model_name,
    training_res,
    val_res,
    test_res,
    training_config,
    tmp_path,
):
    # Use the tmp_path as base folder
    training_config["training_params"]["output_folder"] = str(
        tmp_path / training_config["training_params"]["output_folder"]
    )

    train_and_test(0, training_config)

    # Check that the trained model is correct
    for model_name in [expected_best_model_name, expected_last_model_name]:
        expected_model = torch.load(FIXTURES / "training" / "models" / model_name)
        trained_model = torch.load(
            tmp_path
            / training_config["training_params"]["output_folder"]
            / "checkpoints"
            / model_name
        )

        # Check the optimizers parameters
        for trained, expected in zip(
            trained_model["optimizers_named_params"]["encoder"],
            expected_model["optimizers_named_params"]["encoder"],
        ):
            for (trained_param, trained_tensor), (
                expected_param,
                expected_tensor,
            ) in zip(trained.items(), expected.items()):
                assert trained_param == expected_param
                assert torch.allclose(trained_tensor, expected_tensor, atol=1e-03)

        # Check the optimizer encoder and decoder state dicts
        for optimizer_part in [
            "optimizer_encoder_state_dict",
            "optimizer_decoder_state_dict",
        ]:
            for trained, expected in zip(
                trained_model[optimizer_part]["state"].values(),
                expected_model[optimizer_part]["state"].values(),
            ):
                for (trained_param, trained_tensor), (
                    expected_param,
                    expected_tensor,
                ) in zip(trained.items(), expected.items()):
                    assert trained_param == expected_param
                    assert torch.allclose(
                        trained_tensor,
                        expected_tensor,
                        atol=1e-04,
                    )
            assert (
                trained_model[optimizer_part]["param_groups"]
                == expected_model[optimizer_part]["param_groups"]
            )

        # Check the encoder and decoder weights
        for model_part in ["encoder_state_dict", "decoder_state_dict"]:
            for (trained_name, trained_layer), (expected_name, expected_layer) in zip(
                trained_model[model_part].items(), expected_model[model_part].items()
            ):
                assert trained_name == expected_name
                assert torch.allclose(
                    trained_layer,
                    expected_layer,
                    atol=1e-03,
                )

        # Check the other information
        for elt in [
            "epoch",
            "step",
            "scaler_state_dict",
            "best",
            "charset",
            "curriculum_config",
        ]:
            assert trained_model[elt] == expected_model[elt]

    # Check that the evaluation results are correct
    for split_name, expected_res in zip(
        ["train", "val", "test"], [training_res, val_res, test_res]
    ):
        with (
            tmp_path
            / training_config["training_params"]["output_folder"]
            / "results"
            / f"predict_training-{split_name}_0.yaml"
        ).open() as f:
            # Remove the times from the results as they vary
            res = {
                metric: value
                for metric, value in yaml.safe_load(f).items()
                if "time" not in metric
            }
            assert res == expected_res