# test_prediction.py
# -*- coding: utf-8 -*-

import json
import shutil

import pytest

from dan.predict.prediction import DAN
from dan.predict.prediction import run as run_prediction


@pytest.mark.parametrize(
    "image_name, expected_prediction",
    (
        (
            "0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84.png",
            {"text": ["ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241"]},
        ),
        (
            "0dfe8bcd-ed0b-453e-bf19-cc697012296e.png",
            {"text": ["ⓈTemplié ⒻMarcelle Ⓑ93 ⓁS Ⓚch ⓄE dactylo Ⓟ18376"]},
        ),
        (
            "2c242f5c-e979-43c4-b6f2-a6d4815b651d.png",
            {"text": ["Ⓢd ⒻCharles Ⓑ11 ⓁP ⒸC ⓀF Ⓞd Ⓟ14 31"]},
        ),
        (
            "ffdec445-7f14-4f5f-be44-68d0844d0df1.png",
            {"text": ["ⓈNaudin ⒻMarie Ⓑ53 ⓁS Ⓒv ⓀBelle mère"]},
        ),
    ),
)
def test_predict(
    image_name,
    expected_prediction,
    prediction_data_path,
):
    """Check the direct output of ``DAN.predict`` for one image.

    A ``DAN`` instance is created for the CPU, loaded in ``eval`` mode from
    the model, parameters and charset files under ``prediction_data_path``,
    then fed a single preprocessed image. The returned value must equal
    ``expected_prediction``.
    """
    target_device = "cpu"

    model = DAN(target_device)
    model.load(
        prediction_data_path / "popp_line_model.pt",
        prediction_data_path / "parameters.yml",
        prediction_data_path / "charset.pkl",
        mode="eval",
    )

    # ``preprocess`` returns a pair; only its second item is needed here.
    _, tensor = model.preprocess(str(prediction_data_path / "images" / image_name))

    # Build a batch of one and record the image shape past its first axis.
    batch = tensor.unsqueeze(0).to(target_device)
    sizes = [tensor.shape[1:]]

    assert model.predict(batch, sizes) == expected_prediction


@pytest.mark.parametrize(
    "image_name, confidence_score, temperature, expected_prediction",
    (
        (
            "0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84",
            None,
            1.0,
            {"text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241"},
        ),
        (
            "0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84",
            ["word"],
            1.0,
            {
                "text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",
                "confidences": {
                    "by ner token": [
                        {"text": "ⓈBellisson ", "confidence_ner": "1.0"},
                        {"text": "ⒻGeorges ", "confidence_ner": "1.0"},
                        {"text": "Ⓑ91 ", "confidence_ner": "1.0"},
                        {"text": "ⓁP ", "confidence_ner": "1.0"},
                        {"text": "ⒸM ", "confidence_ner": "1.0"},
                        {"text": "ⓀCh ", "confidence_ner": "1.0"},
                        {"text": "ⓄPlombier ", "confidence_ner": "1.0"},
                        {"text": "ⓅPatron?12241", "confidence_ner": "1.0"},
                    ],
                    "total": 1.0,
                    "word": [
                        {"text": "ⓈBellisson", "confidence": 1.0},
                        {"text": "ⒻGeorges", "confidence": 1.0},
                        {"text": "Ⓑ91", "confidence": 1.0},
                        {"text": "ⓁP", "confidence": 1.0},
                        {"text": "ⒸM", "confidence": 1.0},
                        {"text": "ⓀCh", "confidence": 1.0},
                        {"text": "ⓄPlombier", "confidence": 1.0},
                        {"text": "ⓅPatron?12241", "confidence": 1.0},
                    ],
                },
            },
        ),
        (
            "0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84",
            ["word"],
            3.5,
            {
                "text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",
                "confidences": {
                    "by ner token": [
                        {"text": "ⓈBellisson ", "confidence_ner": "0.92"},
                        {"text": "ⒻGeorges ", "confidence_ner": "0.94"},
                        {"text": "Ⓑ91 ", "confidence_ner": "0.93"},
                        {"text": "ⓁP ", "confidence_ner": "0.92"},
                        {"text": "ⒸM ", "confidence_ner": "0.93"},
                        {"text": "ⓀCh ", "confidence_ner": "0.95"},
                        {"text": "ⓄPlombier ", "confidence_ner": "0.93"},
                        {"text": "ⓅPatron?12241", "confidence_ner": "0.93"},
                    ],
                    "total": 0.93,
                    "word": [
                        {"text": "ⓈBellisson", "confidence": 0.93},
                        {"text": "ⒻGeorges", "confidence": 0.94},
                        {"text": "Ⓑ91", "confidence": 0.92},
                        {"text": "ⓁP", "confidence": 0.94},
                        {"text": "ⒸM", "confidence": 0.93},
                        {"text": "ⓀCh", "confidence": 0.96},
                        {"text": "ⓄPlombier", "confidence": 0.94},
                        {"text": "ⓅPatron?12241", "confidence": 0.93},
                    ],
                },
            },
        ),
        (
            "0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84",
            ["line"],
            1.0,
            {
                "text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",
                "confidences": {
                    "by ner token": [
                        {"text": "ⓈBellisson ", "confidence_ner": "1.0"},
                        {"text": "ⒻGeorges ", "confidence_ner": "1.0"},
                        {"text": "Ⓑ91 ", "confidence_ner": "1.0"},
                        {"text": "ⓁP ", "confidence_ner": "1.0"},
                        {"text": "ⒸM ", "confidence_ner": "1.0"},
                        {"text": "ⓀCh ", "confidence_ner": "1.0"},
                        {"text": "ⓄPlombier ", "confidence_ner": "1.0"},
                        {"text": "ⓅPatron?12241", "confidence_ner": "1.0"},
                    ],
                    "total": 1.0,
                    "line": [
                        {
                            "text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",
                            "confidence": 1.0,
                        }
                    ],
                },
            },
        ),
        (
            "0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84",
            ["line"],
            3.5,
            {
                "text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",
                "confidences": {
                    "by ner token": [
                        {"text": "ⓈBellisson ", "confidence_ner": "0.92"},
                        {"text": "ⒻGeorges ", "confidence_ner": "0.94"},
                        {"text": "Ⓑ91 ", "confidence_ner": "0.93"},
                        {"text": "ⓁP ", "confidence_ner": "0.92"},
                        {"text": "ⒸM ", "confidence_ner": "0.93"},
                        {"text": "ⓀCh ", "confidence_ner": "0.95"},
                        {"text": "ⓄPlombier ", "confidence_ner": "0.93"},
                        {"text": "ⓅPatron?12241", "confidence_ner": "0.93"},
                    ],
                    "total": 0.93,
                    "line": [
                        {
                            "text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",
                            "confidence": 0.93,
                        }
                    ],
                },
            },
        ),
        (
            "0dfe8bcd-ed0b-453e-bf19-cc697012296e",
            None,
            1.0,
            {"text": "ⓈTemplié ⒻMarcelle Ⓑ93 ⓁS Ⓚch ⓄE dactylo Ⓟ18376"},
        ),
        (
            "0dfe8bcd-ed0b-453e-bf19-cc697012296e",
            ["char", "word", "line"],
            1.0,
            {
                "text": "ⓈTemplié ⒻMarcelle Ⓑ93 ⓁS Ⓚch ⓄE dactylo Ⓟ18376",
                "confidences": {
                    "by ner token": [
                        {"text": "ⓈTemplié ", "confidence_ner": "0.98"},
                        {"text": "ⒻMarcelle ", "confidence_ner": "1.0"},
                        {"text": "Ⓑ93 ", "confidence_ner": "1.0"},
                        {"text": "ⓁS ", "confidence_ner": "1.0"},
                        {"text": "Ⓚch ", "confidence_ner": "1.0"},
                        {"text": "ⓄE dactylo ", "confidence_ner": "1.0"},
                        {"text": "Ⓟ18376", "confidence_ner": "1.0"},
                    ],
                    "total": 1.0,
                    "char": [
                        {"text": "", "confidence": 1.0},
                        {"text": "T", "confidence": 1.0},
                        {"text": "e", "confidence": 1.0},
                        {"text": "m", "confidence": 1.0},
                        {"text": "p", "confidence": 1.0},
                        {"text": "l", "confidence": 1.0},
                        {"text": "i", "confidence": 1.0},
                        {"text": "é", "confidence": 0.85},
                        {"text": " ", "confidence": 1.0},
                        {"text": "", "confidence": 1.0},
                        {"text": "M", "confidence": 1.0},
                        {"text": "a", "confidence": 1.0},
                        {"text": "r", "confidence": 1.0},
                        {"text": "c", "confidence": 1.0},
                        {"text": "e", "confidence": 1.0},
                        {"text": "l", "confidence": 1.0},
                        {"text": "l", "confidence": 1.0},
                        {"text": "e", "confidence": 1.0},
                        {"text": " ", "confidence": 1.0},
                        {"text": "", "confidence": 1.0},
                        {"text": "9", "confidence": 1.0},
                        {"text": "3", "confidence": 1.0},
                        {"text": " ", "confidence": 1.0},
                        {"text": "", "confidence": 1.0},
                        {"text": "S", "confidence": 1.0},
                        {"text": " ", "confidence": 1.0},
                        {"text": "", "confidence": 1.0},
                        {"text": "c", "confidence": 1.0},
                        {"text": "h", "confidence": 1.0},
                        {"text": " ", "confidence": 1.0},
                        {"text": "", "confidence": 1.0},
                        {"text": "E", "confidence": 1.0},
                        {"text": " ", "confidence": 1.0},
                        {"text": "d", "confidence": 1.0},
                        {"text": "a", "confidence": 1.0},
                        {"text": "c", "confidence": 1.0},
                        {"text": "t", "confidence": 1.0},
                        {"text": "y", "confidence": 1.0},
                        {"text": "l", "confidence": 1.0},
                        {"text": "o", "confidence": 1.0},
                        {"text": " ", "confidence": 1.0},
                        {"text": "", "confidence": 1.0},
                        {"text": "1", "confidence": 1.0},
                        {"text": "8", "confidence": 1.0},
                        {"text": "3", "confidence": 1.0},
                        {"text": "7", "confidence": 1.0},
                        {"text": "6", "confidence": 1.0},
                    ],
                    "word": [
                        {"text": "ⓈTemplié", "confidence": 0.98},
                        {"text": "ⒻMarcelle", "confidence": 1.0},
                        {"text": "Ⓑ93", "confidence": 1.0},
                        {"text": "ⓁS", "confidence": 1.0},
                        {"text": "Ⓚch", "confidence": 1.0},
                        {"text": "ⓄE", "confidence": 1.0},
                        {"text": "dactylo", "confidence": 1.0},
                        {"text": "Ⓟ18376", "confidence": 1.0},
                    ],
                    "line": [
                        {
                            "text": "ⓈTemplié ⒻMarcelle Ⓑ93 ⓁS Ⓚch ⓄE dactylo Ⓟ18376",
                            "confidence": 1.0,
                        }
                    ],
                },
            },
        ),
        (
            "2c242f5c-e979-43c4-b6f2-a6d4815b651d",
            False,
            1.0,
            {"text": "Ⓢd ⒻCharles Ⓑ11 ⓁP ⒸC ⓀF Ⓞd Ⓟ14 31"},
        ),
        (
            "ffdec445-7f14-4f5f-be44-68d0844d0df1",
            False,
            1.0,
            {"text": "ⓈNaudin ⒻMarie Ⓑ53 ⓁS Ⓒv ⓀBelle mère"},
        ),
    ),
)
def test_run_prediction(
    image_name,
    confidence_score,
    temperature,
    expected_prediction,
    prediction_data_path,
    tmp_path,
):
    """Run the prediction entry point on one image and check the JSON output.

    ``confidence_score`` is either a falsy value (``None`` / ``False``), which
    disables confidence scores, or a list of levels to request. The result is
    read back from ``<tmp_path>/<image_name>.json`` and compared to
    ``expected_prediction``.
    """
    source_image = (prediction_data_path / "images" / image_name).with_suffix(".png")
    # A falsy parametrized value means "no confidence scores, no levels".
    levels = confidence_score or []

    run_prediction(
        image=source_image,
        image_dir=None,
        model=prediction_data_path / "popp_line_model.pt",
        parameters=prediction_data_path / "parameters.yml",
        charset=prediction_data_path / "charset.pkl",
        output=tmp_path,
        confidence_score=bool(confidence_score),
        confidence_score_levels=levels,
        attention_map=False,
        attention_map_level=None,
        attention_map_scale=0.5,
        word_separators=[" ", "\n"],
        line_separators=["\n"],
        temperature=temperature,
        predict_objects=False,
        threshold_method="otsu",
        threshold_value=0,
        image_extension=None,
        gpu_device=None,
        batch_size=1,
        tokens=prediction_data_path / "tokens.yml",
    )

    result_file = (tmp_path / image_name).with_suffix(".json")
    assert json.loads(result_file.read_text()) == expected_prediction


@pytest.mark.parametrize(
    "image_names, confidence_score, temperature, expected_predictions",
    (
        (
            ["0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84"],
            None,
            1.0,
            [{"text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241"}],
        ),
        (
            ["0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84"],
            ["word"],
            1.0,
            [
                {
                    "text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",
                    "confidences": {
                        "by ner token": [
                            {"text": "ⓈBellisson ", "confidence_ner": "1.0"},
                            {"text": "ⒻGeorges ", "confidence_ner": "1.0"},
                            {"text": "Ⓑ91 ", "confidence_ner": "1.0"},
                            {"text": "ⓁP ", "confidence_ner": "1.0"},
                            {"text": "ⒸM ", "confidence_ner": "1.0"},
                            {"text": "ⓀCh ", "confidence_ner": "1.0"},
                            {"text": "ⓄPlombier ", "confidence_ner": "1.0"},
                            {"text": "ⓅPatron?12241", "confidence_ner": "1.0"},
                        ],
                        "total": 1.0,
                        "word": [
                            {"text": "ⓈBellisson", "confidence": 1.0},
                            {"text": "ⒻGeorges", "confidence": 1.0},
                            {"text": "Ⓑ91", "confidence": 1.0},
                            {"text": "ⓁP", "confidence": 1.0},
                            {"text": "ⒸM", "confidence": 1.0},
                            {"text": "ⓀCh", "confidence": 1.0},
                            {"text": "ⓄPlombier", "confidence": 1.0},
                            {"text": "ⓅPatron?12241", "confidence": 1.0},
                        ],
                    },
                }
            ],
        ),
        (
            # NOTE(review): the same image name is listed twice, so both copies
            # are written to the same destination file and a single image ends
            # up in the directory. Use two distinct images if the intent is to
            # exercise a real multi-image batch.
            [
                "0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84",
                "0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84",
            ],
            ["word"],
            1.0,
            [
                {
                    "text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",
                    "confidences": {
                        "by ner token": [
                            {"text": "ⓈBellisson ", "confidence_ner": "1.0"},
                            {"text": "ⒻGeorges ", "confidence_ner": "1.0"},
                            {"text": "Ⓑ91 ", "confidence_ner": "1.0"},
                            {"text": "ⓁP ", "confidence_ner": "1.0"},
                            {"text": "ⒸM ", "confidence_ner": "1.0"},
                            {"text": "ⓀCh ", "confidence_ner": "1.0"},
                            {"text": "ⓄPlombier ", "confidence_ner": "1.0"},
                            {"text": "ⓅPatron?12241", "confidence_ner": "1.0"},
                        ],
                        "total": 1.0,
                        "word": [
                            {"text": "ⓈBellisson", "confidence": 1.0},
                            {"text": "ⒻGeorges", "confidence": 1.0},
                            {"text": "Ⓑ91", "confidence": 1.0},
                            {"text": "ⓁP", "confidence": 1.0},
                            {"text": "ⒸM", "confidence": 1.0},
                            {"text": "ⓀCh", "confidence": 1.0},
                            {"text": "ⓄPlombier", "confidence": 1.0},
                            {"text": "ⓅPatron?12241", "confidence": 1.0},
                        ],
                    },
                },
                {
                    "text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",
                    "confidences": {
                        "by ner token": [
                            {"text": "ⓈBellisson ", "confidence_ner": "1.0"},
                            {"text": "ⒻGeorges ", "confidence_ner": "1.0"},
                            {"text": "Ⓑ91 ", "confidence_ner": "1.0"},
                            {"text": "ⓁP ", "confidence_ner": "1.0"},
                            {"text": "ⒸM ", "confidence_ner": "1.0"},
                            {"text": "ⓀCh ", "confidence_ner": "1.0"},
                            {"text": "ⓄPlombier ", "confidence_ner": "1.0"},
                            {"text": "ⓅPatron?12241", "confidence_ner": "1.0"},
                        ],
                        "total": 1.0,
                        "word": [
                            {"text": "ⓈBellisson", "confidence": 1.0},
                            {"text": "ⒻGeorges", "confidence": 1.0},
                            {"text": "Ⓑ91", "confidence": 1.0},
                            {"text": "ⓁP", "confidence": 1.0},
                            {"text": "ⒸM", "confidence": 1.0},
                            {"text": "ⓀCh", "confidence": 1.0},
                            {"text": "ⓄPlombier", "confidence": 1.0},
                            {"text": "ⓅPatron?12241", "confidence": 1.0},
                        ],
                    },
                },
            ],
        ),
        (
            [
                "2c242f5c-e979-43c4-b6f2-a6d4815b651d",
                "ffdec445-7f14-4f5f-be44-68d0844d0df1",
            ],
            False,
            1.0,
            [
                {"text": "Ⓢd ⒻCharles Ⓑ11 ⓁP ⒸC ⓀF Ⓞd Ⓟ14 31"},
                {"text": "ⓈNaudin ⒻMarie Ⓑ53 ⓁS Ⓒv ⓀBelle mère"},
            ],
        ),
    ),
)
@pytest.mark.parametrize("batch_size", [1, 2])
def test_run_prediction_batch(
    image_names,
    confidence_score,
    temperature,
    expected_predictions,
    prediction_data_path,
    batch_size,
    tmp_path,
):
    """Run the prediction entry point on a directory of images.

    The requested images are copied into a fresh ``images`` directory under
    ``tmp_path``; ``run_prediction`` is then called with ``image_dir`` set and
    the parametrized ``batch_size``. For every image name, the JSON file
    ``<tmp_path>/<image_name>.json`` must match the expected prediction at the
    same position in ``expected_predictions``.

    ``confidence_score`` is either a falsy value (``None`` / ``False``), which
    disables confidence scores, or a list of levels to request.
    """
    # ``zip`` below stops at the shortest input, so a malformed parametrize
    # case would otherwise pass while silently skipping some checks.
    assert len(image_names) == len(expected_predictions)

    # Make tmpdir and copy needed images inside
    image_dir = tmp_path / "images"
    image_dir.mkdir()
    for image_name in image_names:
        shutil.copyfile(
            (prediction_data_path / "images" / image_name).with_suffix(".png"),
            (image_dir / image_name).with_suffix(".png"),
        )

    run_prediction(
        image=None,
        image_dir=image_dir,
        model=prediction_data_path / "popp_line_model.pt",
        parameters=prediction_data_path / "parameters.yml",
        charset=prediction_data_path / "charset.pkl",
        output=tmp_path,
        confidence_score=bool(confidence_score),
        # A falsy parametrized value means "no levels requested".
        confidence_score_levels=confidence_score or [],
        attention_map=False,
        attention_map_level=None,
        attention_map_scale=0.5,
        word_separators=[" ", "\n"],
        line_separators=["\n"],
        temperature=temperature,
        predict_objects=False,
        threshold_method="otsu",
        threshold_value=0,
        image_extension=".png",
        gpu_device=None,
        batch_size=batch_size,
        tokens=prediction_data_path / "tokens.yml",
    )

    for image_name, expected_prediction in zip(image_names, expected_predictions):
        prediction = json.loads(
            (tmp_path / image_name).with_suffix(".json").read_text()
        )
        assert prediction == expected_prediction