# -*- coding: utf-8 -*-
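"""Tests for the evaluation command (dan.ocr.evaluate) and the metrics table helpers."""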

import shutil
from pathlib import Path

import pytest
import yaml
from prettytable import PrettyTable

from dan.ocr import evaluate
from dan.ocr.utils import add_metrics_table_row, create_metrics_table
from tests import FIXTURES


def test_create_metrics_table():
    metric_names = ["ignored", "wer", "cer", "time", "ner"]
    metrics_table = create_metrics_table(metric_names)
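    # Only the CER, WER and NER metrics are expected to get a column;
    # "ignored" and "time" should be left out of the table.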

    assert isinstance(metrics_table, PrettyTable)
    assert metrics_table.field_names == [
        "Split",
        "CER (HTR-NER)",
        "WER (HTR-NER)",
        "NER",
    ]


def test_add_metrics_table_row():
    metric_names = ["ignored", "wer", "cer", "time", "ner"]
    metrics_table = create_metrics_table(metric_names)

    metrics = {
        "ignored": "whatever",
        "wer": 1.0,
        "cer": 1.3023,
        "time": 42,
    }
    add_metrics_table_row(metrics_table, "train", metrics)

    assert isinstance(metrics_table, PrettyTable)
    assert metrics_table.field_names == [
        "Split",
        "CER (HTR-NER)",
        "WER (HTR-NER)",
        "NER",
    ]
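    # CER/WER values should appear as percentages in the table, and the
    # missing "ner" metric should leave an empty cell in the row.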
    assert metrics_table.rows == [["train", 130.23, 100, ""]]


@pytest.mark.parametrize(
    "training_res, val_res, test_res",
    (
        (
            {
                "nb_chars": 43,
                "cer": 1.3023,
                "nb_chars_no_token": 43,
                "cer_no_token": 1.3023,
                "nb_words": 9,
                "wer": 1.0,
                "nb_words_no_punct": 9,
                "wer_no_punct": 1.0,
                "nb_words_no_token": 9,
                "wer_no_token": 1.0,
                "nb_samples": 2,
            },
            {
                "nb_chars": 41,
                "cer": 1.2683,
                "nb_chars_no_token": 41,
                "cer_no_token": 1.2683,
                "nb_words": 9,
                "wer": 1.0,
                "nb_words_no_punct": 9,
                "wer_no_punct": 1.0,
                "nb_words_no_token": 9,
                "wer_no_token": 1.0,
                "nb_samples": 2,
            },
            {
                "nb_chars": 49,
                "cer": 1.1224,
                "nb_chars_no_token": 49,
                "cer_no_token": 1.1224,
                "nb_words": 9,
                "wer": 1.0,
                "nb_words_no_punct": 9,
                "wer_no_punct": 1.0,
                "nb_words_no_token": 9,
                "wer_no_token": 1.0,
                "nb_samples": 2,
            },
        ),
    ),
)
def test_evaluate(capsys, training_res, val_res, test_res, evaluate_config):
    # Use the "evaluate" fixtures directory as the output folder
    evaluate_config["training"]["output_folder"] = FIXTURES / "evaluate"

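    # Running the evaluation should produce one YAML result file per split
    # under the "results/" subfolder of the output folder.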
    evaluate.run(evaluate_config)

    # Check that the evaluation results are correct
    for split_name, expected_res in zip(
        ["train", "val", "test"], [training_res, val_res, test_res]
    ):
        filename = (
            evaluate_config["training"]["output_folder"]
            / "results"
            / f"predict_training-{split_name}_0.yaml"
        )

        with filename.open() as f:
            # Remove the times from the results as they vary
            res = {
                metric: value
                for metric, value in yaml.safe_load(f).items()
                if "time" not in metric
            }
            assert res == expected_res

    # Remove results files
    shutil.rmtree(evaluate_config["training"]["output_folder"] / "results")

    # Check the metrics Markdown table
    captured_std = capsys.readouterr()
    last_printed_lines = captured_std.out.split("\n")[-6:]
    assert (
        "\n".join(last_printed_lines)
        == Path(FIXTURES / "evaluate" / "metrics_table.md").read_text()
    )