Skip to content
Snippets Groups Projects
test_evaluate.py 12.47 KiB
# -*- coding: utf-8 -*-

import shutil
from pathlib import Path

import pytest
import yaml
from prettytable import PrettyTable

from dan.ocr import evaluate
from dan.ocr.manager.metrics import Inference
from dan.ocr.utils import add_metrics_table_row, create_metrics_table
from tests import FIXTURES

# Fixture folder holding prediction assets (the language model used by
# `test_evaluate_language_model` is read from here)
PREDICTION_DATA_PATH = FIXTURES / "prediction"


def test_create_metrics_table():
    """Build a table header from a list of metric names.

    Metric names with no matching column ("ignored", "time") produce no
    header, and CER is listed before WER whatever the input order.
    """
    table = create_metrics_table(["ignored", "wer", "cer", "time", "ner"])

    expected_columns = ["Split", "CER (HTR-NER)", "WER (HTR-NER)", "NER"]
    assert isinstance(table, PrettyTable)
    assert table.field_names == expected_columns


def test_add_metrics_table_row():
    """Append one split's metrics as a row of an existing metrics table."""
    table = create_metrics_table(["ignored", "wer", "cer", "time", "ner"])

    add_metrics_table_row(
        table,
        "train",
        {"ignored": "whatever", "wer": 1.0, "cer": 1.3023, "time": 42},
    )

    assert isinstance(table, PrettyTable)
    assert table.field_names == ["Split", "CER (HTR-NER)", "WER (HTR-NER)", "NER"]
    # CER/WER are shown as percentages; "ner" was not provided, hence ""
    assert table.rows == [["train", 130.23, 100, ""]]


@pytest.mark.parametrize(
    "training_res, val_res, test_res",
    (
        (
            {
                "nb_chars": 90,
                "cer": 0.1889,
                "nb_chars_no_token": 76,
                "cer_no_token": 0.2105,
                "nb_words": 15,
                "wer": 0.2667,
                "nb_words_no_punct": 15,
                "wer_no_punct": 0.2667,
                "nb_words_no_token": 15,
                "wer_no_token": 0.2667,
                "nb_tokens": 14,
                "ner": 0.0714,
                "nb_samples": 2,
            },
            {
                "nb_chars": 34,
                "cer": 0.0882,
                "nb_chars_no_token": 26,
                "cer_no_token": 0.1154,
                "nb_words": 8,
                "wer": 0.5,
                "nb_words_no_punct": 8,
                "wer_no_punct": 0.5,
                "nb_words_no_token": 8,
                "wer_no_token": 0.5,
                "nb_tokens": 8,
                "ner": 0.0,
                "nb_samples": 1,
            },
            {
                "nb_chars": 36,
                "cer": 0.0278,
                "nb_chars_no_token": 30,
                "cer_no_token": 0.0333,
                "nb_words": 7,
                "wer": 0.1429,
                "nb_words_no_punct": 7,
                "wer_no_punct": 0.1429,
                "nb_words_no_token": 7,
                "wer_no_token": 0.1429,
                "nb_tokens": 6,
                "ner": 0.0,
                "nb_samples": 1,
            },
        ),
    ),
)
def test_evaluate(capsys, training_res, val_res, test_res, evaluate_config):
    """Run the evaluation on the fixtures and check its outputs.

    Checks the per-split YAML result files (ignoring timing metrics) and the
    metrics table printed on stdout. The ``results`` folder written under
    ``FIXTURES / "evaluate"`` is removed in a ``finally`` clause so that a
    failing assertion does not leave stale files in the fixtures folder.
    """
    # Write the evaluation outputs inside the "evaluate" fixtures folder
    evaluate_config["training"]["output_folder"] = FIXTURES / "evaluate"

    try:
        evaluate.run(evaluate_config, evaluate.NERVAL_THRESHOLD)

        # Check that the evaluation results are correct
        for split_name, expected_res in zip(
            ["train", "val", "test"], [training_res, val_res, test_res]
        ):
            filename = (
                evaluate_config["training"]["output_folder"]
                / "results"
                / f"predict_training-{split_name}_1685.yaml"
            )

            assert {
                metric: value
                for metric, value in yaml.safe_load(filename.read_bytes()).items()
                # Remove the times from the results as they vary
                if "time" not in metric
            } == expected_res
    finally:
        # Remove results files even if an assertion above failed; errors are
        # ignored so a missing folder cannot mask the original failure
        shutil.rmtree(
            evaluate_config["training"]["output_folder"] / "results",
            ignore_errors=True,
        )

    # Check the metrics Markdown table
    captured_std = capsys.readouterr()
    # Only the output after the first 10 printed lines is compared with the
    # reference table (NOTE(review): presumably those lines are other
    # evaluation output printed before the table — confirm)
    last_printed_lines = captured_std.out.split("\n")[10:]
    assert (
        "\n".join(last_printed_lines)
        == Path(FIXTURES / "evaluate" / "metrics_table.md").read_text()
    )


@pytest.mark.parametrize(
    "language_model_weight, expected_inferences",
    (
        (
            0.0,
            [
                (
                    "0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84.png",  # Image
                    "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier Ⓟ12241",  # Ground truth
                    "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",  # Prediction
                    "",  # LM prediction
                    0.125,  # WER
                ),
                (
                    "0dfe8bcd-ed0b-453e-bf19-cc697012296e.png",  # Image
                    "ⓈTemplié ⒻMarcelle Ⓑ93 ⓁJ Ⓚch ⓄE dachyle",  # Ground truth
                    "ⓈTemplié ⒻMarcelle Ⓑ93 ⓁS Ⓚch ⓄE dactylo Ⓟ18376",  # Prediction
                    "",  # LM prediction
                    0.2667,  # WER
                ),
                (
                    "2c242f5c-e979-43c4-b6f2-a6d4815b651d.png",  # Image
                    "ⓈA ⒻCharles Ⓑ11 ⓁP ⒸC ⓀF ⓄA Ⓟ14331",  # Ground truth
                    "Ⓢd ⒻCharles Ⓑ11 ⓁP ⒸC ⓀF Ⓞd Ⓟ14 31",  # Prediction
                    "",  # LM prediction
                    0.5,  # WER
                ),
                (
                    "ffdec445-7f14-4f5f-be44-68d0844d0df1.png",  # Image
                    "ⓈNaudin ⒻMarie Ⓑ53 ⓁS ⒸV ⓀBelle mère",  # Ground truth
                    "ⓈNaudin ⒻMarie Ⓑ53 ⓁS Ⓒv ⓀBelle mère",  # Prediction
                    "",  # LM prediction
                    0.1429,  # WER
                ),
            ],
        ),
        (
            1.0,
            [
                (
                    "0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84.png",  # Image
                    "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier Ⓟ12241",  # Ground truth
                    "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",  # Prediction
                    "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",  # LM prediction
                    0.125,  # WER
                ),
                (
                    "0dfe8bcd-ed0b-453e-bf19-cc697012296e.png",  # Image
                    "ⓈTemplié ⒻMarcelle Ⓑ93 ⓁJ Ⓚch ⓄE dachyle",  # Ground truth
                    "ⓈTemplié ⒻMarcelle Ⓑ93 ⓁS Ⓚch ⓄE dactylo Ⓟ18376",  # Prediction
                    "ⓈTemplié ⒻMarcelle Ⓑ93 ⓁS Ⓚch ⓄE dactylo Ⓟ18376",  # LM prediction
                    0.2667,  # WER
                ),
                (
                    "2c242f5c-e979-43c4-b6f2-a6d4815b651d.png",  # Image
                    "ⓈA ⒻCharles Ⓑ11 ⓁP ⒸC ⓀF ⓄA Ⓟ14331",  # Ground truth
                    "Ⓢd ⒻCharles Ⓑ11 ⓁP ⒸC ⓀF Ⓞd Ⓟ14 31",  # Prediction
                    "Ⓢd ⒻCharles Ⓑ11 ⓁP ⒸC ⓀF Ⓞd Ⓟ14 31",  # LM prediction
                    0.5,  # WER
                ),
                (
                    "ffdec445-7f14-4f5f-be44-68d0844d0df1.png",  # Image
                    "ⓈNaudin ⒻMarie Ⓑ53 ⓁS ⒸV ⓀBelle mère",  # Ground truth
                    "ⓈNaudin ⒻMarie Ⓑ53 ⓁS Ⓒv ⓀBelle mère",  # Prediction
                    "ⓈNaudin ⒻMarie Ⓑ53 ⓁS Ⓒv ⓀBelle mère",  # LM prediction
                    0.1429,  # WER
                ),
            ],
        ),
        (
            2.0,
            [
                (
                    "0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84.png",  # Image
                    "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier Ⓟ12241",  # Ground truth
                    "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",  # Prediction
                    "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",  # LM prediction
                    0.125,  # WER
                ),
                (
                    "0dfe8bcd-ed0b-453e-bf19-cc697012296e.png",  # Image
                    "ⓈTemplié ⒻMarcelle Ⓑ93 ⓁJ Ⓚch ⓄE dachyle",  # Ground truth
                    "ⓈTemplié ⒻMarcelle Ⓑ93 ⓁS Ⓚch ⓄE dactylo Ⓟ18376",  # Prediction
                    "ⓈTemplié ⒻMarcelle Ⓑ93 ⓁS Ⓚch ⓄE dactylo Ⓟ18376",  # LM prediction
                    0.2667,  # WER
                ),
                (
                    "2c242f5c-e979-43c4-b6f2-a6d4815b651d.png",  # Image
                    "ⓈA ⒻCharles Ⓑ11 ⓁP ⒸC ⓀF ⓄA Ⓟ14331",  # Ground truth
                    "Ⓢd ⒻCharles Ⓑ11 ⓁP ⒸC ⓀF Ⓞd Ⓟ14 31",  # Prediction
                    "Ⓢd ⒻCharles Ⓑ11 ⓁP ⒸC ⓀF Ⓞd Ⓟ14331",  # LM prediction
                    0.5,  # WER
                ),
                (
                    "ffdec445-7f14-4f5f-be44-68d0844d0df1.png",  # Image
                    "ⓈNaudin ⒻMarie Ⓑ53 ⓁS ⒸV ⓀBelle mère",  # Ground truth
                    "ⓈNaudin ⒻMarie Ⓑ53 ⓁS Ⓒv ⓀBelle mère",  # Prediction
                    "ⓈNaudin ⒻMarie Ⓑ53 ⓁS Ⓒv ⓀBelle mère",  # LM prediction
                    0.1429,  # WER
                ),
            ],
        ),
    ),
)
def test_evaluate_language_model(
    capsys, evaluate_config, language_model_weight, expected_inferences, monkeypatch
):
    """Run the evaluation with a language-model decoder at the given weight.

    Besides the per-split metrics and the printed table, every ``Inference``
    built by ``dan.ocr.manager.training`` is compared, in order, against
    ``expected_inferences``, and the test checks that all of them were
    produced. The ``results`` folder written under ``FIXTURES / "evaluate"``
    is removed in a ``finally`` clause so failures leave no stale files.
    """
    # LM predictions are never used/displayed, so we mock the `Inference`
    # class to temporarily check them as each inference is built.
    # A function-local counter (updated via `nonlocal`) avoids leaking
    # mutable state into the module globals.
    nb_inferences = 0

    class MockInference(Inference):
        def __new__(cls, *args, **kwargs):
            nonlocal nb_inferences
            assert nb_inferences < len(
                expected_inferences
            ), "More inferences were built than expected"
            assert args == expected_inferences[nb_inferences]
            nb_inferences += 1

            return super().__new__(cls, *args, **kwargs)

    monkeypatch.setattr("dan.ocr.manager.training.Inference", MockInference)

    # Write the evaluation outputs inside the "evaluate" fixtures folder
    evaluate_config["training"]["output_folder"] = FIXTURES / "evaluate"

    # Use a LM decoder
    evaluate_config["model"]["lm"] = {
        "path": PREDICTION_DATA_PATH / "language_model.arpa",
        "weight": language_model_weight,
    }

    try:
        evaluate.run(evaluate_config, evaluate.NERVAL_THRESHOLD)

        # Without this, the mock's checks would pass vacuously if no (or too
        # few) inferences went through the patched class
        assert nb_inferences == len(expected_inferences)

        # Check that the evaluation results are correct
        for split_name, expected_res in [
            (
                "train",
                {
                    "nb_chars": 90,
                    "cer": 0.1889,
                    "nb_chars_no_token": 76,
                    "cer_no_token": 0.2105,
                    "nb_words": 15,
                    "wer": 0.2667,
                    "nb_words_no_punct": 15,
                    "wer_no_punct": 0.2667,
                    "nb_words_no_token": 15,
                    "wer_no_token": 0.2667,
                    "nb_tokens": 14,
                    "ner": 0.0714,
                    "nb_samples": 2,
                },
            ),
            (
                "val",
                {
                    "nb_chars": 34,
                    "cer": 0.0882,
                    "nb_chars_no_token": 26,
                    "cer_no_token": 0.1154,
                    "nb_words": 8,
                    "wer": 0.5,
                    "nb_words_no_punct": 8,
                    "wer_no_punct": 0.5,
                    "nb_words_no_token": 8,
                    "wer_no_token": 0.5,
                    "nb_tokens": 8,
                    "ner": 0.0,
                    "nb_samples": 1,
                },
            ),
            (
                "test",
                {
                    "nb_chars": 36,
                    "cer": 0.0278,
                    "nb_chars_no_token": 30,
                    "cer_no_token": 0.0333,
                    "nb_words": 7,
                    "wer": 0.1429,
                    "nb_words_no_punct": 7,
                    "wer_no_punct": 0.1429,
                    "nb_words_no_token": 7,
                    "wer_no_token": 0.1429,
                    "nb_tokens": 6,
                    "ner": 0.0,
                    "nb_samples": 1,
                },
            ),
        ]:
            filename = (
                evaluate_config["training"]["output_folder"]
                / "results"
                / f"predict_training-{split_name}_1685.yaml"
            )

            # Read bytes so YAML decoding does not depend on the locale
            assert {
                metric: value
                for metric, value in yaml.safe_load(filename.read_bytes()).items()
                # Remove the times from the results as they vary
                if "time" not in metric
            } == expected_res
    finally:
        # Remove results files even if an assertion above failed; errors are
        # ignored so a missing folder cannot mask the original failure
        shutil.rmtree(
            evaluate_config["training"]["output_folder"] / "results",
            ignore_errors=True,
        )

    # Check the metrics Markdown table
    captured_std = capsys.readouterr()
    # Only the output after the first 10 printed lines is compared with the
    # reference table (NOTE(review): presumably those lines are other
    # evaluation output printed before the table — confirm)
    last_printed_lines = captured_std.out.split("\n")[10:]
    assert (
        "\n".join(last_printed_lines)
        == Path(FIXTURES / "evaluate" / "metrics_table.md").read_text()
    )