# -*- coding: utf-8 -*-

import pytest
import torch
import yaml

from dan.ocr.train import train_and_test
from tests.conftest import FIXTURES


def _assert_tensor_dict_close(trained, expected, atol=1e-03):
    """Assert that two name -> tensor mappings hold matching entries.

    Entries are compared pairwise in iteration order: names must be equal and
    tensors must be close within ``atol`` (see ``torch.allclose``).
    """
    # zip() silently stops at the shorter input; check lengths first so a
    # missing or extra entry is reported instead of ignored.
    assert len(trained) == len(expected)
    for (trained_name, trained_tensor), (expected_name, expected_tensor) in zip(
        trained.items(), expected.items()
    ):
        assert trained_name == expected_name
        assert torch.allclose(trained_tensor, expected_tensor, atol=atol)


@pytest.mark.parametrize(
    "expected_best_model_name, expected_last_model_name, training_res, val_res, test_res, params_res",
    (
        (
            "best_0.pt",
            "last_3.pt",
            {
                "nb_chars": 43,
                "cer": 1.3023,
                "nb_words": 9,
                "wer": 1.0,
                "nb_words_no_punct": 9,
                "wer_no_punct": 1.0,
                "nb_samples": 2,
            },
            {
                "nb_chars": 41,
                "cer": 1.2683,
                "nb_words": 9,
                "wer": 1.0,
                "nb_words_no_punct": 9,
                "wer_no_punct": 1.0,
                "nb_samples": 2,
            },
            {
                "nb_chars": 49,
                "cer": 1.1224,
                "nb_words": 9,
                "wer": 1.0,
                "nb_words_no_punct": 9,
                "wer_no_punct": 1.0,
                "nb_samples": 2,
            },
            {
                "parameters": {
                    "max_char_prediction": 30,
                    "encoder": {"dropout": 0.5},
                    "decoder": {
                        "enc_dim": 256,
                        "l_max": 15000,
                        "h_max": 500,
                        "w_max": 1000,
                        "dec_num_layers": 8,
                        "dec_num_heads": 4,
                        "dec_res_dropout": 0.1,
                        "dec_pred_dropout": 0.1,
                        "dec_att_dropout": 0.1,
                        "dec_dim_feedforward": 256,
                        "vocab_size": 96,
                        "attention_win": 100,
                    },
                    "preprocessings": [
                        {
                            "max_height": 2000,
                            "max_width": 2000,
                            "type": "max_resize",
                        }
                    ],
                    "mean": [
                        242.10595854671013,
                        242.10595854671013,
                        242.10595854671013,
                    ],
                    "std": [28.29919517652322, 28.29919517652322, 28.29919517652322],
                },
            },
        ),
    ),
)
def test_train_and_test(
    expected_best_model_name,
    expected_last_model_name,
    training_res,
    val_res,
    test_res,
    params_res,
    training_config,
    tmp_path,
):
    """Run training + evaluation and compare every artefact to the fixtures.

    Args:
        expected_best_model_name: checkpoint file name of the best model, looked
            up both in the run's ``checkpoints`` folder and in the fixtures.
        expected_last_model_name: checkpoint file name of the last model.
        training_res: expected metrics for the ``train`` split (time metrics
            excluded).
        val_res: expected metrics for the ``val`` split.
        test_res: expected metrics for the ``test`` split.
        params_res: expected content of ``inference_parameters.yml``.
        training_config: training configuration fixture (presumably provided
            by ``tests/conftest.py`` — confirm); mutated to write under
            ``tmp_path``.
        tmp_path: pytest's per-test temporary directory.
    """
    # Use the tmp_path as base folder. Compute the path once and reuse it so
    # later lookups don't rely on pathlib discarding an absolute right operand.
    output_folder = tmp_path / training_config["training_params"]["output_folder"]
    training_config["training_params"]["output_folder"] = str(output_folder)

    train_and_test(0, training_config)

    # Check that the trained model is correct
    for model_name in [expected_best_model_name, expected_last_model_name]:
        # Load on CPU so the comparison does not depend on the device the
        # checkpoints were saved from.
        expected_model = torch.load(
            FIXTURES / "training" / "models" / model_name, map_location="cpu"
        )
        trained_model = torch.load(
            output_folder / "checkpoints" / model_name, map_location="cpu"
        )

        # Check the optimizers parameters
        trained_named_params = trained_model["optimizers_named_params"]["encoder"]
        expected_named_params = expected_model["optimizers_named_params"]["encoder"]
        assert len(trained_named_params) == len(expected_named_params)
        for trained, expected in zip(trained_named_params, expected_named_params):
            _assert_tensor_dict_close(trained, expected)

        # Check the optimizer encoder and decoder state dicts
        for optimizer_part in [
            "optimizer_encoder_state_dict",
            "optimizer_decoder_state_dict",
        ]:
            trained_states = trained_model[optimizer_part]["state"]
            expected_states = expected_model[optimizer_part]["state"]
            assert len(trained_states) == len(expected_states)
            for trained, expected in zip(
                trained_states.values(), expected_states.values()
            ):
                _assert_tensor_dict_close(trained, expected)

            assert (
                trained_model[optimizer_part]["param_groups"]
                == expected_model[optimizer_part]["param_groups"]
            )

        # Check the encoder and decoder weights
        for model_part in ["encoder_state_dict", "decoder_state_dict"]:
            _assert_tensor_dict_close(
                trained_model[model_part], expected_model[model_part]
            )

        # Check the other information
        for elt in [
            "epoch",
            "step",
            "scaler_state_dict",
            "best",
            "charset",
            "curriculum_config",
        ]:
            assert trained_model[elt] == expected_model[elt]

    # Check that the evaluation results are correct
    for split_name, expected_res in zip(
        ["train", "val", "test"], [training_res, val_res, test_res]
    ):
        results_path = (
            output_folder / "results" / f"predict_training-{split_name}_0.yaml"
        )
        with results_path.open(encoding="utf-8") as f:
            # Remove the times from the results as they vary
            res = {
                metric: value
                for metric, value in yaml.safe_load(f).items()
                if "time" not in metric
            }
        assert res == expected_res

    # Check that the inference parameters file is correct
    with (output_folder / "results" / "inference_parameters.yml").open(
        encoding="utf-8"
    ) as f:
        res = yaml.safe_load(f)
    assert res == params_res