# Copyright Teklia (contact@teklia.com) & Denis Coquenet
# This code is licensed under CeCILL-C

import pytest

from dan.ocr.predict.attention import (
    Level,
    parse_delimiters,
    split_text,
    split_text_and_confidences,
)
from dan.utils import EntityType, parse_charset_pattern, parse_tokens_pattern
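
# These tests exercise the prediction text-splitting helpers:
# split_text_and_confidences segments a predicted string at a given level
# (character, word, line or NER entity) and averages the per-character
# confidences over each segment, while split_text performs the same
# segmentation without confidences.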


@pytest.mark.parametrize(
    (
        "text",
        "confidence",
        "level",
        "tokens",
        "expected_split_text",
        "expected_mean_confidences",
        "expected_offsets",
    ),
    [
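        # Each case gives: input text, per-character confidences, split level,
        # NER entity tokens (or None), expected splits, expected per-split mean
        # confidences (rounded to two decimals), and expected offsets, i.e. the
        # number of separator characters skipped between consecutive splits.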
        # level: char
        (
            "To <kyo>",
            [0.1, 0.2, 0.3, 0.4],
            Level.Char,
            None,
            ["T", "o", " ", "<kyo>"],
            [0.1, 0.2, 0.3, 0.4],
            [0, 0, 0, 0],
        ),
        # level: word
        (
            "Lo ve\nTokyo",
            [0.1, 0.1, 0.2, 0.3, 0.3, 0.4, 0.5, 0.5, 0.5, 0.5, 0.5],
            Level.Word,
            None,
            ["Lo", "ve", "Tokyo"],
            [0.1, 0.3, 0.5],
            [1, 1, 0],
        ),
        # level: line
        (
            "Love\nTokyo",
            [0.1, 0.1, 0.1, 0.1, 0.2, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3],
            Level.Line,
            None,
            ["Love", "Tokyo"],
            [0.1, 0.3],
            [1, 0],
        ),
        # level: NER (no end tokens)
        (
            "ⒶLove ⒷTokyo",
            [0.1, 0.2, 0.2, 0.2, 0.2, 0.3, 0.4, 0.5, 0.5, 0.5, 0.5, 0.5],
            Level.NER,
            [EntityType(start=""), EntityType(start="")],
            ["ⒶLove ", "ⒷTokyo"],
            [0.2, 0.48],
            [0, 0],
        ),
        # level: NER (with end tokens)
        (
            "ⓐLoveⒶ ⓑTokyoⒷ",
            [0.1, 0.2, 0.2, 0.2, 0.2, 0.3, 0.4, 0.5, 0.6, 0.6, 0.6, 0.6, 0.6, 0.7],
            Level.NER,
            [EntityType(start="", end=""), EntityType(start="", end="")],
            ["ⓐLoveⒶ", "ⓑTokyoⒷ"],
            [0.2, 0.6],
            [1, 0],
        ),
        # level: NER (no end tokens, no space)
        (
            "ⒶLoveⒷTokyo",
            [0.1, 0.2, 0.2, 0.2, 0.2, 0.3, 0.4, 0.4, 0.4, 0.4, 0.4],
            Level.NER,
            [EntityType(start=""), EntityType(start="")],
            ["ⒶLove", "ⒷTokyo"],
            [0.18, 0.38],
            [0, 0],
        ),
        # level: NER (with end tokens, no space)
        (
            "ⓐLoveⒶⓑTokyoⒷ",
            [0.1, 0.2, 0.2, 0.2, 0.2, 0.3, 0.4, 0.5, 0.5, 0.5, 0.5, 0.5, 0.6],
            Level.NER,
            [EntityType(start="", end=""), EntityType(start="", end="")],
            ["ⓐLoveⒶ", "ⓑTokyoⒷ"],
            [0.2, 0.5],
            [0, 0],
        ),
    ],
)
def test_split_text_and_confidences(
    text: str,
    confidence: list[float],
    level: Level,
    tokens: list[EntityType] | None,
    expected_split_text: list[str],
    expected_mean_confidences: list[float],
    expected_offsets: list[int],
):
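    """Check that split_text_and_confidences and split_text agree on the
    expected segmentation, mean confidences and offsets at every level
    (character, word, line, NER)."""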
    # Full charset
    charset = [
        # alphabet
        "T",
        "o",
        "L",
        "v",
        "e",
        "k",
        "y",
        # Entities
        "",
        "",
        "",
        "",
        # Special
        "<kyo>",
        # Whitespace
        " ",
    ]
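    # Build level-specific separators from the charset and standard delimiters.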
    texts_conf, averages_conf, offsets_conf = split_text_and_confidences(
        text=text,
        confidences=confidence,
        level=level,
        char_separators=parse_charset_pattern(charset),
        word_separators=parse_delimiters([" ", "\n"]),
        line_separators=parse_delimiters(["\n"]),
        tokens_separators=parse_tokens_pattern(tokens) if tokens else None,
    )
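    # split_text performs the same segmentation without confidences;
    # both helpers should agree on splits and offsets.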
    texts, offsets = split_text(
        text=text,
        level=level,
        char_separators=parse_charset_pattern(charset),
        word_separators=parse_delimiters([" ", "\n"]),
        line_separators=parse_delimiters(["\n"]),
        tokens_separators=parse_tokens_pattern(tokens) if tokens else None,
    )

    assert texts == expected_split_text
    assert offsets == expected_offsets
    assert texts_conf == expected_split_text
    assert averages_conf == expected_mean_confidences
    assert offsets_conf == expected_offsets