import shutil
from pathlib import Path

import pytest
import yaml
from prettytable import PrettyTable

# Assumed import locations for the evaluation entry point and the FIXTURES
# constant used below.
from dan.ocr import evaluate
from dan.ocr.utils import add_metrics_table_row, create_metrics_table
from tests import FIXTURES


def test_create_metrics_table():
    metric_names = ["ignored", "wer", "cer", "time", "ner"]
    metrics_table = create_metrics_table(metric_names)

    assert isinstance(metrics_table, PrettyTable)
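    # The table is expected to have a leading "Split" column and one column per
    # recognised metric (cer, wer, ner); "ignored" and "time" produce no column.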
    assert metrics_table.field_names == [
        "Split",
        "CER (HTR-NER)",
        "WER (HTR-NER)",
        "NER",
    ]


def test_add_metrics_table_row():
    metric_names = ["ignored", "wer", "cer", "time", "ner"]
    metrics_table = create_metrics_table(metric_names)

    metrics = {
        "ignored": "whatever",
        "wer": 1.0,
        "cer": 1.3023,
        "time": 42,
    }
    add_metrics_table_row(metrics_table, "train", metrics)

    assert isinstance(metrics_table, PrettyTable)
    assert metrics_table.field_names == [
        "Split",
        "CER (HTR-NER)",
        "WER (HTR-NER)",
        "NER",
    ]
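    # CER and WER are expected to be reported as percentages, and the NER
    # metric, absent from `metrics`, is rendered as a "−" placeholder.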
    assert metrics_table.rows == [["train", 130.23, 100, "−"]]
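

# Expected evaluation metrics for the train, val and test splits of the
# "evaluate" fixture; timing metrics are not listed here because test_evaluate
# strips them before comparing.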
@pytest.mark.parametrize(
    "training_res, val_res, test_res",
    (
        (
            {
                "nb_chars": 90,
                "cer": 0.1889,
                "nb_chars_no_token": 76,
                "cer_no_token": 0.2105,
                "nb_words": 15,
                "wer": 0.2667,
                "nb_words_no_punct": 15,
                "wer_no_punct": 0.2667,
                "nb_words_no_token": 15,
                "wer_no_token": 0.2667,
                "nb_tokens": 14,
                "ner": 0.0714,
            },
            {
                "nb_chars": 34,
                "cer": 0.0882,
                "nb_chars_no_token": 26,
                "cer_no_token": 0.1154,
                "nb_words": 8,
                "wer": 0.5,
                "nb_words_no_punct": 8,
                "wer_no_punct": 0.5,
                "nb_words_no_token": 8,
                "wer_no_token": 0.5,
                "nb_tokens": 8,
                "ner": 0.0,
                "nb_samples": 1,
            },
            {
                "nb_chars": 36,
                "cer": 0.0278,
                "nb_chars_no_token": 30,
                "cer_no_token": 0.0333,
                "nb_words": 7,
                "wer": 0.1429,
                "nb_words_no_punct": 7,
                "wer_no_punct": 0.1429,
                "nb_words_no_token": 7,
                "wer_no_token": 0.1429,
                "nb_tokens": 6,
                "ner": 0.0,
                "nb_samples": 1,
            },
        ),
    ),
)
def test_evaluate(capsys, training_res, val_res, test_res, evaluate_config):
    # Use the "evaluate" fixture folder as the output folder
    evaluate_config["training"]["output_folder"] = FIXTURES / "evaluate"

    evaluate.run(evaluate_config)

    # Check that the evaluation results are correct
    for split_name, expected_res in zip(
        ["train", "val", "test"], [training_res, val_res, test_res]
    ):
        filename = (
            evaluate_config["training"]["output_folder"]
            / "results"
        )

        with filename.open() as f:
            # Remove the times from the results as they vary
            res = {
                metric: value
                for metric, value in yaml.safe_load(f).items()
                if "time" not in metric
            }

        assert res == expected_res

    # Remove results files
    shutil.rmtree(evaluate_config["training"]["output_folder"] / "results")

    # Check the metrics Markdown table
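    # The last printed lines should match the stored metrics_table.md fixture;
    # its columns presumably mirror those built by create_metrics_table above.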
    captured_std = capsys.readouterr()
    last_printed_lines = captured_std.out.split("\n")[-6:]
    assert (
        "\n".join(last_printed_lines)
        == Path(FIXTURES / "evaluate" / "metrics_table.md").read_text()
    )