# test_prediction.py
# -*- coding: utf-8 -*-

import json
import shutil

import numpy as np
import pytest
import yaml

from dan.ocr.predict.attention import Level
from dan.ocr.predict.inference import DAN
from dan.ocr.predict.inference import run as run_prediction
from dan.utils import parse_tokens, read_yaml
from tests import FIXTURES

# Fixture directory used by every test below: it provides model.pt,
# charset.pkl, parameters.yml, tokens.yml and the sample images under images/.
PREDICTION_DATA_PATH = FIXTURES / "prediction"


@pytest.mark.parametrize(
    "image_name, expected_prediction",
    (
        (
            "0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84.png",
            {"text": ["ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241"]},
        ),
        (
            "0dfe8bcd-ed0b-453e-bf19-cc697012296e.png",
            {"text": ["ⓈTemplié ⒻMarcelle Ⓑ93 ⓁS Ⓚch ⓄE dactylo Ⓟ18376"]},
        ),
        (
            "2c242f5c-e979-43c4-b6f2-a6d4815b651d.png",
            {"text": ["Ⓢd ⒻCharles Ⓑ11 ⓁP ⒸC ⓀF Ⓞd Ⓟ14 31"]},
        ),
        (
            "ffdec445-7f14-4f5f-be44-68d0844d0df1.png",
            {"text": ["ⓈNaudin ⒻMarie Ⓑ53 ⓁS Ⓒv ⓀBelle mère"]},
        ),
    ),
)
@pytest.mark.parametrize("normalize", (True, False))
def test_predict(image_name, expected_prediction, normalize, tmp_path):
    """Check ``DAN.predict`` on a single image, with and without normalization.

    A temporary model directory is assembled from the prediction fixtures so
    that ``parameters.yml`` can be edited safely: when ``normalize`` is False
    the ``mean`` and ``std`` entries are removed before the model is loaded.
    """
    # Copy the model into a temporary directory so parameters.yml can be
    # rewritten without touching the shared fixtures.
    model_path = tmp_path / "models"
    model_path.mkdir(exist_ok=True)

    for filename in ("model.pt", "charset.pkl"):
        shutil.copyfile(PREDICTION_DATA_PATH / filename, model_path / filename)

    params = read_yaml(PREDICTION_DATA_PATH / "parameters.yml")
    if not normalize:
        del params["parameters"]["mean"]
        del params["parameters"]["std"]
    # Close (and therefore flush) the file before DAN.load reads it back;
    # leaving the handle open leaks it and raises a ResourceWarning.
    with (model_path / "parameters.yml").open("w", encoding="utf-8") as params_file:
        yaml.dump(params, params_file)

    device = "cpu"

    dan_model = DAN(device)
    dan_model.load(path=model_path, mode="eval")

    image_path = PREDICTION_DATA_PATH / "images" / image_name
    _, image = dan_model.preprocess(str(image_path))

    # Wrap the single image in a batch of one.
    input_tensor = image.unsqueeze(0).to(device)
    # Presumably image is (channels, height, width), so this keeps the
    # spatial size — TODO confirm against DAN.preprocess.
    input_sizes = [image.shape[1:]]

    prediction = dan_model.predict(input_tensor, input_sizes)

    assert prediction == expected_prediction


@pytest.mark.parametrize(
    "image_name, confidence_score, temperature, expected_prediction",
    (
        (
            "0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84",
            None,
            1.0,
            {
                "text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",
                "language_model": {},
                "confidences": {},
            },
        ),
        (
            "0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84",
            [Level.Word],
            1.0,
            {
                "text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",
                "language_model": {},
                "confidences": {
                    "total": 1.0,
                    "word": [
                        {"text": "ⓈBellisson", "confidence": 1.0},
                        {"text": "ⒻGeorges", "confidence": 1.0},
                        {"text": "Ⓑ91", "confidence": 1.0},
                        {"text": "ⓁP", "confidence": 1.0},
                        {"text": "ⒸM", "confidence": 1.0},
                        {"text": "ⓀCh", "confidence": 1.0},
                        {"text": "ⓄPlombier", "confidence": 1.0},
                        {"text": "ⓅPatron?12241", "confidence": 1.0},
                    ],
                },
            },
        ),
        (
            "0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84",
            [Level.NER, Level.Word],
            3.5,
            {
                "text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",
                "language_model": {},
                "confidences": {
                    "total": 0.93,
                    "ner": [
                        {"text": "ⓈBellisson ", "confidence": 0.92},
                        {"text": "ⒻGeorges ", "confidence": 0.94},
                        {"text": "Ⓑ91 ", "confidence": 0.93},
                        {"text": "ⓁP ", "confidence": 0.92},
                        {"text": "ⒸM ", "confidence": 0.93},
                        {"text": "ⓀCh ", "confidence": 0.95},
                        {"text": "ⓄPlombier ", "confidence": 0.93},
                        {"text": "ⓅPatron?12241", "confidence": 0.93},
                    ],
                    "word": [
                        {"text": "ⓈBellisson", "confidence": 0.93},
                        {"text": "ⒻGeorges", "confidence": 0.94},
                        {"text": "Ⓑ91", "confidence": 0.92},
                        {"text": "ⓁP", "confidence": 0.94},
                        {"text": "ⒸM", "confidence": 0.93},
                        {"text": "ⓀCh", "confidence": 0.96},
                        {"text": "ⓄPlombier", "confidence": 0.94},
                        {"text": "ⓅPatron?12241", "confidence": 0.93},
                    ],
                },
            },
        ),
        (
            "0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84",
            [Level.Line],
            1.0,
            {
                "text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",
                "language_model": {},
                "confidences": {
                    "total": 1.0,
                    "line": [
                        {
                            "text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",
                            "confidence": 1.0,
                        }
                    ],
                },
            },
        ),
        (
            "0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84",
            [Level.NER, Level.Line],
            3.5,
            {
                "text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",
                "language_model": {},
                "confidences": {
                    "total": 0.93,
                    "ner": [
                        {"text": "ⓈBellisson ", "confidence": 0.92},
                        {"text": "ⒻGeorges ", "confidence": 0.94},
                        {"text": "Ⓑ91 ", "confidence": 0.93},
                        {"text": "ⓁP ", "confidence": 0.92},
                        {"text": "ⒸM ", "confidence": 0.93},
                        {"text": "ⓀCh ", "confidence": 0.95},
                        {"text": "ⓄPlombier ", "confidence": 0.93},
                        {"text": "ⓅPatron?12241", "confidence": 0.93},
                    ],
                    "line": [
                        {
                            "text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",
                            "confidence": 0.93,
                        }
                    ],
                },
            },
        ),
        (
            "0dfe8bcd-ed0b-453e-bf19-cc697012296e",
            None,
            1.0,
            {
                "text": "ⓈTemplié ⒻMarcelle Ⓑ93 ⓁS Ⓚch ⓄE dactylo Ⓟ18376",
                "language_model": {},
                "confidences": {},
            },
        ),
        (
            "0dfe8bcd-ed0b-453e-bf19-cc697012296e",
            [Level.NER, Level.Char, Level.Word, Level.Line],
            1.0,
            {
                "text": "ⓈTemplié ⒻMarcelle Ⓑ93 ⓁS Ⓚch ⓄE dactylo Ⓟ18376",
                "language_model": {},
                "confidences": {
                    "total": 1.0,
                    "ner": [
                        {"text": "ⓈTemplié ", "confidence": 0.98},
                        {"text": "ⒻMarcelle ", "confidence": 1.0},
                        {"text": "Ⓑ93 ", "confidence": 1.0},
                        {"text": "ⓁS ", "confidence": 1.0},
                        {"text": "Ⓚch ", "confidence": 1.0},
                        {"text": "ⓄE dactylo ", "confidence": 1.0},
                        {"text": "Ⓟ18376", "confidence": 1.0},
                    ],
                    "char": [
                        {"text": "", "confidence": 1.0},
                        {"text": "T", "confidence": 1.0},
                        {"text": "e", "confidence": 1.0},
                        {"text": "m", "confidence": 1.0},
                        {"text": "p", "confidence": 1.0},
                        {"text": "l", "confidence": 1.0},
                        {"text": "i", "confidence": 1.0},
                        {"text": "é", "confidence": 0.85},
                        {"text": " ", "confidence": 1.0},
                        {"text": "", "confidence": 1.0},
                        {"text": "M", "confidence": 1.0},
                        {"text": "a", "confidence": 1.0},
                        {"text": "r", "confidence": 1.0},
                        {"text": "c", "confidence": 1.0},
                        {"text": "e", "confidence": 1.0},
                        {"text": "l", "confidence": 1.0},
                        {"text": "l", "confidence": 1.0},
                        {"text": "e", "confidence": 1.0},
                        {"text": " ", "confidence": 1.0},
                        {"text": "", "confidence": 1.0},
                        {"text": "9", "confidence": 1.0},
                        {"text": "3", "confidence": 1.0},
                        {"text": " ", "confidence": 1.0},
                        {"text": "", "confidence": 1.0},
                        {"text": "S", "confidence": 1.0},
                        {"text": " ", "confidence": 1.0},
                        {"text": "", "confidence": 1.0},
                        {"text": "c", "confidence": 1.0},
                        {"text": "h", "confidence": 1.0},
                        {"text": " ", "confidence": 1.0},
                        {"text": "", "confidence": 1.0},
                        {"text": "E", "confidence": 1.0},
                        {"text": " ", "confidence": 1.0},
                        {"text": "d", "confidence": 1.0},
                        {"text": "a", "confidence": 1.0},
                        {"text": "c", "confidence": 1.0},
                        {"text": "t", "confidence": 1.0},
                        {"text": "y", "confidence": 1.0},
                        {"text": "l", "confidence": 1.0},
                        {"text": "o", "confidence": 1.0},
                        {"text": " ", "confidence": 1.0},
                        {"text": "", "confidence": 1.0},
                        {"text": "1", "confidence": 1.0},
                        {"text": "8", "confidence": 1.0},
                        {"text": "3", "confidence": 1.0},
                        {"text": "7", "confidence": 1.0},
                        {"text": "6", "confidence": 1.0},
                    ],
                    "word": [
                        {"text": "ⓈTemplié", "confidence": 0.98},
                        {"text": "ⒻMarcelle", "confidence": 1.0},
                        {"text": "Ⓑ93", "confidence": 1.0},
                        {"text": "ⓁS", "confidence": 1.0},
                        {"text": "Ⓚch", "confidence": 1.0},
                        {"text": "ⓄE", "confidence": 1.0},
                        {"text": "dactylo", "confidence": 1.0},
                        {"text": "Ⓟ18376", "confidence": 1.0},
                    ],
                    "line": [
                        {
                            "text": "ⓈTemplié ⒻMarcelle Ⓑ93 ⓁS Ⓚch ⓄE dactylo Ⓟ18376",
                            "confidence": 1.0,
                        }
                    ],
                },
            },
        ),
        (
            "2c242f5c-e979-43c4-b6f2-a6d4815b651d",
            None,
            1.0,
            {
                "text": "Ⓢd ⒻCharles Ⓑ11 ⓁP ⒸC ⓀF Ⓞd Ⓟ14 31",
                "language_model": {},
                "confidences": {},
            },
        ),
        (
            "ffdec445-7f14-4f5f-be44-68d0844d0df1",
            None,
            1.0,
            {
                "text": "ⓈNaudin ⒻMarie Ⓑ53 ⓁS Ⓒv ⓀBelle mère",
                "language_model": {},
                "confidences": {},
            },
        ),
    ),
)
def test_run_prediction(
    image_name,
    confidence_score,
    temperature,
    expected_prediction,
    tmp_path,
):
    """Run the prediction CLI entry point on one image and check its JSON output.

    ``confidence_score`` is either ``None`` (no confidence scores requested)
    or the list of ``Level`` values for which confidences are computed.
    """
    # Make tmpdir and copy needed image inside
    image_dir = tmp_path / "images"
    image_dir.mkdir()
    shutil.copyfile(
        (PREDICTION_DATA_PATH / "images" / image_name).with_suffix(".png"),
        (image_dir / image_name).with_suffix(".png"),
    )

    run_prediction(
        image_dir=image_dir,
        model=PREDICTION_DATA_PATH,
        output=tmp_path,
        confidence_score=bool(confidence_score),
        confidence_score_levels=confidence_score or [],
        attention_map=False,
        attention_map_level=None,
        attention_map_scale=0.5,
        word_separators=[" ", "\n"],
        line_separators=["\n"],
        temperature=temperature,
        predict_objects=False,
        max_object_height=None,
        image_extension=".png",
        gpu_device=None,
        batch_size=1,
        tokens=parse_tokens(PREDICTION_DATA_PATH / "tokens.yml"),
        start_token=None,
        use_language_model=False,
    )

    # The output holds non-ASCII text (NER tokens, accents): decode it
    # explicitly as UTF-8 rather than with the platform default encoding.
    prediction = json.loads(
        (tmp_path / image_name).with_suffix(".json").read_text(encoding="utf-8")
    )
    assert prediction == expected_prediction


@pytest.mark.parametrize(
    "image_names, confidence_score, temperature, expected_predictions",
    (
        (
            ["0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84"],
            None,
            1.0,
            [
                {
                    "text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",
                    "language_model": {},
                    "confidences": {},
                }
            ],
        ),
        (
            ["0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84"],
            [Level.Word],
            1.0,
            [
                {
                    "text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",
                    "language_model": {},
                    "confidences": {
                        "total": 1.0,
                        "word": [
                            {"text": "ⓈBellisson", "confidence": 1.0},
                            {"text": "ⒻGeorges", "confidence": 1.0},
                            {"text": "Ⓑ91", "confidence": 1.0},
                            {"text": "ⓁP", "confidence": 1.0},
                            {"text": "ⒸM", "confidence": 1.0},
                            {"text": "ⓀCh", "confidence": 1.0},
                            {"text": "ⓄPlombier", "confidence": 1.0},
                            {"text": "ⓅPatron?12241", "confidence": 1.0},
                        ],
                    },
                }
            ],
        ),
        (
            [
                "0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84",
                "0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84",
            ],
            [Level.NER, Level.Word],
            1.0,
            [
                {
                    "text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",
                    "language_model": {},
                    "confidences": {
                        "total": 1.0,
                        "ner": [
                            {"text": "ⓈBellisson ", "confidence": 1.0},
                            {"text": "ⒻGeorges ", "confidence": 1.0},
                            {"text": "Ⓑ91 ", "confidence": 1.0},
                            {"text": "ⓁP ", "confidence": 1.0},
                            {"text": "ⒸM ", "confidence": 1.0},
                            {"text": "ⓀCh ", "confidence": 1.0},
                            {"text": "ⓄPlombier ", "confidence": 1.0},
                            {"text": "ⓅPatron?12241", "confidence": 1.0},
                        ],
                        "word": [
                            {"text": "ⓈBellisson", "confidence": 1.0},
                            {"text": "ⒻGeorges", "confidence": 1.0},
                            {"text": "Ⓑ91", "confidence": 1.0},
                            {"text": "ⓁP", "confidence": 1.0},
                            {"text": "ⒸM", "confidence": 1.0},
                            {"text": "ⓀCh", "confidence": 1.0},
                            {"text": "ⓄPlombier", "confidence": 1.0},
                            {"text": "ⓅPatron?12241", "confidence": 1.0},
                        ],
                    },
                },
                {
                    "text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",
                    "language_model": {},
                    "confidences": {
                        "total": 1.0,
                        "ner": [
                            {"text": "ⓈBellisson ", "confidence": 1.0},
                            {"text": "ⒻGeorges ", "confidence": 1.0},
                            {"text": "Ⓑ91 ", "confidence": 1.0},
                            {"text": "ⓁP ", "confidence": 1.0},
                            {"text": "ⒸM ", "confidence": 1.0},
                            {"text": "ⓀCh ", "confidence": 1.0},
                            {"text": "ⓄPlombier ", "confidence": 1.0},
                            {"text": "ⓅPatron?12241", "confidence": 1.0},
                        ],
                        "word": [
                            {"text": "ⓈBellisson", "confidence": 1.0},
                            {"text": "ⒻGeorges", "confidence": 1.0},
                            {"text": "Ⓑ91", "confidence": 1.0},
                            {"text": "ⓁP", "confidence": 1.0},
                            {"text": "ⒸM", "confidence": 1.0},
                            {"text": "ⓀCh", "confidence": 1.0},
                            {"text": "ⓄPlombier", "confidence": 1.0},
                            {"text": "ⓅPatron?12241", "confidence": 1.0},
                        ],
                    },
                },
            ],
        ),
        (
            [
                "2c242f5c-e979-43c4-b6f2-a6d4815b651d",
                "ffdec445-7f14-4f5f-be44-68d0844d0df1",
            ],
            None,
            1.0,
            [
                {
                    "text": "Ⓢd ⒻCharles Ⓑ11 ⓁP ⒸC ⓀF Ⓞd Ⓟ14 31",
                    "language_model": {},
                    "confidences": {},
                },
                {
                    "text": "ⓈNaudin ⒻMarie Ⓑ53 ⓁS Ⓒv ⓀBelle mère",
                    "language_model": {},
                    "confidences": {},
                },
            ],
        ),
    ),
)
@pytest.mark.parametrize("batch_size", [1, 2])
def test_run_prediction_batch(
    image_names,
    confidence_score,
    temperature,
    expected_predictions,
    batch_size,
    tmp_path,
):
    """Run the prediction entry point on several images with batching.

    Each image produces its own JSON file in ``tmp_path``; every file is
    compared with the expectation at the same position in
    ``expected_predictions``. ``confidence_score`` is ``None`` (disabled) or
    the list of ``Level`` values to compute confidences for.
    """
    # Guard against a mis-written parametrization: zip() would otherwise
    # silently skip images that have no expected prediction.
    assert len(image_names) == len(expected_predictions)

    # Make tmpdir and copy needed images inside
    image_dir = tmp_path / "images"
    image_dir.mkdir()
    for image_name in image_names:
        shutil.copyfile(
            (PREDICTION_DATA_PATH / "images" / image_name).with_suffix(".png"),
            (image_dir / image_name).with_suffix(".png"),
        )

    run_prediction(
        image_dir=image_dir,
        model=PREDICTION_DATA_PATH,
        output=tmp_path,
        confidence_score=bool(confidence_score),
        confidence_score_levels=confidence_score or [],
        attention_map=False,
        attention_map_level=None,
        attention_map_scale=0.5,
        word_separators=[" ", "\n"],
        line_separators=["\n"],
        temperature=temperature,
        predict_objects=False,
        max_object_height=None,
        image_extension=".png",
        gpu_device=None,
        batch_size=batch_size,
        tokens=parse_tokens(PREDICTION_DATA_PATH / "tokens.yml"),
        start_token=None,
        use_language_model=False,
    )

    for image_name, expected_prediction in zip(image_names, expected_predictions):
        # Decode explicitly as UTF-8: the predictions contain non-ASCII text.
        prediction = json.loads(
            (tmp_path / image_name).with_suffix(".json").read_text(encoding="utf-8")
        )
        assert prediction == expected_prediction


@pytest.mark.parametrize(
    "image_names, language_model_weight, expected_predictions",
    (
        (
            [
                "0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84",
                "0dfe8bcd-ed0b-453e-bf19-cc697012296e",
                "2c242f5c-e979-43c4-b6f2-a6d4815b651d",
                "ffdec445-7f14-4f5f-be44-68d0844d0df1",
            ],
            1.0,
            [
                {
                    "text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",
                    "language_model": {
                        "text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",
                        "confidence": 0.92,
                    },
                },
                {
                    "text": "ⓈTemplié ⒻMarcelle Ⓑ93 ⓁS Ⓚch ⓄE dactylo Ⓟ18376",
                    "language_model": {
                        "text": "ⓈTemplié ⒻMarcelle Ⓑ93 ⓁS Ⓚch ⓄE dactylo Ⓟ18376",
                        "confidence": 0.88,
                    },
                },
                {
                    "text": "Ⓢd ⒻCharles Ⓑ11 ⓁP ⒸC ⓀF Ⓞd Ⓟ14 31",
                    "language_model": {
                        "text": "Ⓢd ⒻCharles Ⓑ11 ⓁP ⒸC ⓀF Ⓞd Ⓟ14 31",
                        "confidence": 0.86,
                    },
                },
                {
                    "text": "ⓈNaudin ⒻMarie Ⓑ53 ⓁS Ⓒv ⓀBelle mère",
                    "language_model": {
                        "text": "ⓈNaudin ⒻMarie Ⓑ53 ⓁS Ⓒv ⓀBelle mère",
                        "confidence": 0.89,
                    },
                },
            ],
        ),
        (
            [
                "0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84",
                "0dfe8bcd-ed0b-453e-bf19-cc697012296e",
                "2c242f5c-e979-43c4-b6f2-a6d4815b651d",
                "ffdec445-7f14-4f5f-be44-68d0844d0df1",
            ],
            2.0,
            [
                {
                    "text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",
                    "language_model": {
                        "text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",
                        "confidence": 0.90,
                    },
                },
                {
                    "text": "ⓈTemplié ⒻMarcelle Ⓑ93 ⓁS Ⓚch ⓄE dactylo Ⓟ18376",
                    "language_model": {
                        "text": "ⓈTemplié ⒻMarcelle Ⓑ93 ⓁS Ⓚch ⓄE dactylo Ⓟ18376",
                        "confidence": 0.84,
                    },
                },
                {
                    "text": "Ⓢd ⒻCharles Ⓑ11 ⓁP ⒸC ⓀF Ⓞd Ⓟ14 31",
                    "language_model": {
                        "text": "Ⓢd ⒻCharles Ⓑ11 ⓁP ⒸC ⓀF Ⓞd Ⓟ14331",
                        "confidence": 0.83,
                    },
                },
                {
                    "text": "ⓈNaudin ⒻMarie Ⓑ53 ⓁS Ⓒv ⓀBelle mère",
                    "language_model": {
                        "text": "ⓈNaudin ⒻMarie Ⓑ53 ⓁS Ⓒv ⓀBelle mère",
                        "confidence": 0.86,
                    },
                },
            ],
        ),
        (
            [
                "0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84",
                "0dfe8bcd-ed0b-453e-bf19-cc697012296e",
                "2c242f5c-e979-43c4-b6f2-a6d4815b651d",
                "ffdec445-7f14-4f5f-be44-68d0844d0df1",
            ],
            0.0,
            [
                {"text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241"},
                {"text": "ⓈTemplié ⒻMarcelle Ⓑ93 ⓁS Ⓚch ⓄE dactylo Ⓟ18376"},
                {"text": "Ⓢd ⒻCharles Ⓑ11 ⓁP ⒸC ⓀF Ⓞd Ⓟ14 31"},
                {"text": "ⓈNaudin ⒻMarie Ⓑ53 ⓁS Ⓒv ⓀBelle mère"},
            ],
        ),
    ),
)
def test_run_prediction_language_model(
    image_names,
    language_model_weight,
    expected_predictions,
    tmp_path,
):
    """Check predictions when the language model is enabled.

    The fixture ``parameters.yml`` is copied with
    ``parameters.language_model.weight`` set to ``language_model_weight``.
    The raw text is always compared. The language-model text and confidence
    are compared only when the weight is positive; the confidence comparison
    is approximate (``np.isclose``).
    """
    # Guard against a mis-written parametrization: zip() would otherwise
    # silently skip images that have no expected prediction.
    assert len(image_names) == len(expected_predictions)

    # Make tmpdir and copy needed images inside
    image_dir = tmp_path / "images"
    image_dir.mkdir()
    for image_name in image_names:
        shutil.copyfile(
            (PREDICTION_DATA_PATH / "images" / image_name).with_suffix(".png"),
            (image_dir / image_name).with_suffix(".png"),
        )

    # Build a temporary model directory so language_model_weight can be
    # changed in parameters.yml without touching the shared fixtures.
    model_path = tmp_path / "models"
    model_path.mkdir(exist_ok=True)

    for filename in ("model.pt", "charset.pkl"):
        shutil.copyfile(PREDICTION_DATA_PATH / filename, model_path / filename)

    params = read_yaml(PREDICTION_DATA_PATH / "parameters.yml")
    params["parameters"]["language_model"]["weight"] = language_model_weight
    # Close (and therefore flush) the file before run_prediction reads it.
    with (model_path / "parameters.yml").open("w", encoding="utf-8") as params_file:
        yaml.dump(params, params_file)

    run_prediction(
        image_dir=image_dir,
        model=model_path,
        output=tmp_path,
        confidence_score=False,
        confidence_score_levels=[],
        # Boolean flag, consistent with the other run_prediction calls.
        attention_map=False,
        attention_map_level=None,
        attention_map_scale=0.5,
        word_separators=[" ", "\n"],
        line_separators=["\n"],
        temperature=1.0,
        predict_objects=False,
        max_object_height=None,
        image_extension=".png",
        gpu_device=None,
        batch_size=1,
        tokens=parse_tokens(PREDICTION_DATA_PATH / "tokens.yml"),
        start_token=None,
        use_language_model=True,
    )

    for image_name, expected_prediction in zip(image_names, expected_predictions):
        # Decode explicitly as UTF-8: the predictions contain non-ASCII text.
        prediction = json.loads(
            (tmp_path / image_name).with_suffix(".json").read_text(encoding="utf-8")
        )
        assert prediction["text"] == expected_prediction["text"]

        if language_model_weight > 0:
            assert (
                prediction["language_model"]["text"]
                == expected_prediction["language_model"]["text"]
            )
            assert np.isclose(
                prediction["language_model"]["confidence"],
                expected_prediction["language_model"]["confidence"],
            )