# -*- coding: utf-8 -*-
import shutil
from pathlib import Path

import pytest
import yaml
from prettytable import PrettyTable

from dan.ocr import evaluate
from dan.ocr.utils import add_metrics_table_row, create_metrics_table
from tests import FIXTURES


def test_create_metrics_table():
    """Check the type and the column headers of a freshly created metrics table.

    The "ignored" and "time" metric names are expected not to produce a column.
    """
    metric_names = ["ignored", "wer", "cer", "time", "ner"]
    metrics_table = create_metrics_table(metric_names)

    assert isinstance(metrics_table, PrettyTable)
    assert metrics_table.field_names == [
        "Split",
        "CER (HTR-NER)",
        "WER (HTR-NER)",
        "NER",
    ]


def test_add_metrics_table_row():
    """Check the row added to a metrics table for the "train" split.

    The metrics deliberately contain no "ner" value: the expected row shows
    the "−" placeholder in that column, and the CER/WER values are expected
    to be displayed as 130.23 and 100.
    """
    metric_names = ["ignored", "wer", "cer", "time", "ner"]
    metrics_table = create_metrics_table(metric_names)

    metrics = {
        "ignored": "whatever",
        "wer": 1.0,
        "cer": 1.3023,
        "time": 42,
    }
    add_metrics_table_row(metrics_table, "train", metrics)

    assert isinstance(metrics_table, PrettyTable)
    # Adding a row must not alter the headers
    assert metrics_table.field_names == [
        "Split",
        "CER (HTR-NER)",
        "WER (HTR-NER)",
        "NER",
    ]
    assert metrics_table.rows == [["train", 130.23, 100, "−"]]


@pytest.mark.parametrize(
    "training_res, val_res, test_res",
    (
        (
            {
                "nb_chars": 90,
                "cer": 0.1889,
                "nb_chars_no_token": 76,
                "cer_no_token": 0.2105,
                "nb_words": 15,
                "wer": 0.2667,
                "nb_words_no_punct": 15,
                "wer_no_punct": 0.2667,
                "nb_words_no_token": 15,
                "wer_no_token": 0.2667,
                "nb_tokens": 14,
                "ner": 0.0714,
                "nb_samples": 2,
            },
            {
                "nb_chars": 34,
                "cer": 0.0882,
                "nb_chars_no_token": 26,
                "cer_no_token": 0.1154,
                "nb_words": 8,
                "wer": 0.5,
                "nb_words_no_punct": 8,
                "wer_no_punct": 0.5,
                "nb_words_no_token": 8,
                "wer_no_token": 0.5,
                "nb_tokens": 8,
                "ner": 0.0,
                "nb_samples": 1,
            },
            {
                "nb_chars": 36,
                "cer": 0.0278,
                "nb_chars_no_token": 30,
                "cer_no_token": 0.0333,
                "nb_words": 7,
                "wer": 0.1429,
                "nb_words_no_punct": 7,
                "wer_no_punct": 0.1429,
                "nb_words_no_token": 7,
                "wer_no_token": 0.1429,
                "nb_tokens": 6,
                "ner": 0.0,
                "nb_samples": 1,
            },
        ),
    ),
)
def test_evaluate(capsys, training_res, val_res, test_res, evaluate_config):
    """Run the evaluation and check the per-split results and the printed table.

    :param capsys: pytest fixture used to capture what is printed on stdout.
    :param training_res: Expected metrics (times excluded) for the "train" split.
    :param val_res: Expected metrics (times excluded) for the "val" split.
    :param test_res: Expected metrics (times excluded) for the "test" split.
    :param evaluate_config: Configuration given to ``evaluate.run``
        (presumably a fixture defined elsewhere in the test suite).
    """
    # Use the "evaluate" fixtures folder as output folder
    evaluate_config["training"]["output_folder"] = FIXTURES / "evaluate"
    results_folder = evaluate_config["training"]["output_folder"] / "results"

    try:
        evaluate.run(evaluate_config)

        # Check that the evaluation results are correct
        for split_name, expected_res in zip(
            ["train", "val", "test"], [training_res, val_res, test_res]
        ):
            filename = results_folder / f"predict_training-{split_name}_1685.yaml"

            with filename.open() as f:
                # Remove the times from the results as they vary
                res = {
                    metric: value
                    for metric, value in yaml.safe_load(f).items()
                    if "time" not in metric
                }
            assert res == expected_res
    finally:
        # Always remove the results files, even when a check above fails, so
        # that nothing is left behind in the fixtures folder. Errors are
        # ignored so a missing folder does not hide the original failure.
        shutil.rmtree(results_folder, ignore_errors=True)

    # Check the metrics Markdown table: it is expected to be the last
    # six lines printed on stdout
    captured_std = capsys.readouterr()
    last_printed_lines = captured_std.out.split("\n")[-6:]
    assert (
        "\n".join(last_printed_lines)
        == Path(FIXTURES / "evaluate" / "metrics_table.md").read_text()
    )