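"""Tests for dan.ocr.predict.attention.split_text_and_confidences.

The function is expected to split a predicted text at the requested level
(character, word, line or NER entity), pair each segment with the mean of its
character confidences, and return, for each segment, the number of separator
characters that follow it (the offsets).
"""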
import pytest

from dan.ocr.predict.attention import (
    Level,
    parse_delimiters,
    split_text_and_confidences,
)
from dan.utils import EntityType, parse_tokens_pattern


@pytest.mark.parametrize(
    (
        "text",
        "confidence",
        "level",
        "tokens",
        "split_text",
        "mean_confidences",
        "expected_offsets",
    ),
    [
        # level: char
        (
            "Tok",
            [0.1, 0.2, 0.3],
            Level.Char,
            None,
            ["T", "o", "k"],
            [0.1, 0.2, 0.3],
            [0, 0, 0],
        ),
        # level: word
        (
            "Lo ve\nTokyo",
            [0.1, 0.1, 0.2, 0.3, 0.3, 0.4, 0.5, 0.5, 0.5, 0.5, 0.5],
            Level.Word,
            None,
            ["Lo", "ve", "Tokyo"],
            [0.1, 0.3, 0.5],
            [1, 1, 0],
        ),
        # level: line
        (
            "Love\nTokyo",
            [0.1, 0.1, 0.1, 0.1, 0.2, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3],
            Level.Line,
            None,
            ["Love", "Tokyo"],
            [0.1, 0.3],
            [1, 0],
        ),
        # level: NER (no end tokens)
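        # (segment confidences appear to be rounded to 2 decimals:
        # the mean of the second segment is 2.9 / 6 ≈ 0.4833 -> 0.48)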
        (
            "ⒶLove ⒷTokyo",
            [0.1, 0.2, 0.2, 0.2, 0.2, 0.3, 0.4, 0.5, 0.5, 0.5, 0.5, 0.5],
            Level.NER,
            [EntityType(start="Ⓐ"), EntityType(start="Ⓑ")],
            ["ⒶLove ", "ⒷTokyo"],
            [0.2, 0.48],
            [0, 0],
        ),
        # level: NER (with end tokens)
        (
            "ⓐLoveⒶ ⓑTokyoⒷ",
            [0.1, 0.2, 0.2, 0.2, 0.2, 0.3, 0.4, 0.5, 0.6, 0.6, 0.6, 0.6, 0.6, 0.7],
            Level.NER,
            [EntityType(start="ⓐ", end="Ⓐ"), EntityType(start="ⓑ", end="Ⓑ")],
            ["ⓐLoveⒶ", "ⓑTokyoⒷ"],
            [0.2, 0.6],
            [1, 0],
        ),
        # level: NER (no end tokens, no space)
        (
            "ⒶLoveⒷTokyo",
            [0.1, 0.2, 0.2, 0.2, 0.2, 0.3, 0.4, 0.4, 0.4, 0.4, 0.4],
            Level.NER,
            [EntityType(start="Ⓐ"), EntityType(start="Ⓑ")],
            ["ⒶLove", "ⒷTokyo"],
            [0.18, 0.38],
            [0, 0],
        ),
        # level: NER (with end tokens, no space)
        (
            "ⓐLoveⒶⓑTokyoⒷ",
            [0.1, 0.2, 0.2, 0.2, 0.2, 0.3, 0.4, 0.5, 0.5, 0.5, 0.5, 0.5, 0.6],
            Level.NER,
            [EntityType(start="ⓐ", end="Ⓐ"), EntityType(start="ⓑ", end="Ⓑ")],
            ["ⓐLoveⒶ", "ⓑTokyoⒷ"],
            [0.2, 0.5],
            [0, 0],
        ),
    ],
)
def test_split_text_and_confidences(
    text: str,
    confidence: list[float],
    level: Level,
    tokens: list[EntityType] | None,
    split_text: list[str],
    mean_confidences: list[float],
    expected_offsets: list[int],
):
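    """Check that the text is split as expected at the given level, with one
    mean confidence and one trailing-separator offset per segment."""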
    texts, averages, offsets = split_text_and_confidences(
        text=text,
        confidences=confidence,
        level=level,
        word_separators=parse_delimiters([" ", "\n"]),
        line_separators=parse_delimiters(["\n"]),
        tokens_separators=parse_tokens_pattern(tokens) if tokens else None,
    )

    assert texts == split_text
    assert averages == mean_confidences
    assert offsets == expected_offsets