# -*- coding: utf-8 -*-

import json
import logging
import pickle
import re
from operator import attrgetter, methodcaller
from pathlib import Path
from typing import NamedTuple
from unittest.mock import patch

import pytest
from PIL import Image, ImageChops

from arkindex_export import Element, Transcription
from dan.datasets.extract.arkindex import IIIF_FULL_SIZE, ArkindexExtractor
from dan.datasets.extract.exceptions import (
    NoEndTokenError,
    NoTranscriptionError,
    UnknownTokenInText,
)
from dan.datasets.extract.utils import (
    EntityType,
    download_image,
    insert_token,
    normalize_linebreaks,
    normalize_spaces,
)
from dan.utils import parse_tokens
from line_image_extractor.image_utils import BoundingBox, polygon_to_bbox
from tests import FIXTURES

EXTRACTION_DATA_PATH = FIXTURES / "extraction"

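# Collapses double spaces in the expected labels when keep_spaces is False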
TWO_SPACES_REGEX = re.compile(r" {2}")
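# Matches an entity token (ⓢ, ⓕ or ⓑ) followed by a space, used to strip tokens from the expected corpora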
ENTITY_TOKEN_SPACE = re.compile(r"[ⓢⓕⓑ] ")
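# Matches two consecutive space markers in the language model corpora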
TWO_SPACES_LM_REGEX = re.compile(r"▁ ▁")

# NamedTuple used to mock an actual database result
Entity = NamedTuple(
    "Entity", [("offset", int), ("length", int), ("type", str), ("value", str)]
)

TOKENS = {
    "P": EntityType(start="ⓟ", end="Ⓟ"),
    "D": EntityType(start="ⓓ", end="Ⓓ"),
    "N": EntityType(start="ⓝ", end="Ⓝ"),
    "I": EntityType(start="ⓘ", end="Ⓘ"),
}


def filter_tokens(keys):
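    """Restrict TOKENS to the given entity-type keys."""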
    return {key: value for key, value in TOKENS.items() if key in keys}


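# get_iiif_size_arg should return IIIF_FULL_SIZE when the image already fits within
# the maximum dimensions, and otherwise a single IIIF constraint ("<width>," or
# ",<height>") on one dimension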
@pytest.mark.parametrize(
    "max_width, max_height, width, height, resize",
    (
        (1000, 2000, 900, 800, IIIF_FULL_SIZE),
        (1000, 2000, 1100, 800, "1000,"),
        (1000, 2000, 1100, 2800, ",2000"),
        (1000, 2000, 2000, 3000, "1000,"),
    ),
)
def test_get_iiif_size_arg(max_width, max_height, width, height, resize):
    assert (
        ArkindexExtractor(max_width=max_width, max_height=max_height).get_iiif_size_arg(
            width=width, height=height
        )
        == resize
    )


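# insert_token should return the entity span wrapped with its start and end tokens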
@pytest.mark.parametrize(
    "text,offset,length,expected",
    (
        ("n°1 16 janvier 1611", 0, 3, "ⓘn°1Ⓘ"),
        ("ⓘn°1Ⓘ 16 janvier 1611", 6, 15, "ⓘ16 janvier 1611Ⓘ"),
    ),
)
def test_insert_token(text, offset, length, expected):
    assert (
        insert_token(text, EntityType(start="ⓘ", end="Ⓘ"), offset, length) == expected
    )


def test_reconstruct_text_no_end_token_error():
    arkindex_extractor = ArkindexExtractor(entity_separators=[])
    arkindex_extractor.tokens = {
        "X": EntityType(start="ⓧ"),
    }
    with pytest.raises(NoEndTokenError, match="Label `X` has no end token."):
        arkindex_extractor.reconstruct_text(
            "n°1 x 16 janvier 1611",
            [
                Entity(
                    offset=0,
                    length=3,
                    type="X",
                    value="n°1",
                ),
            ],
        )


@pytest.mark.parametrize(
    "entity_separators,tokens,expected",
    (
        # Whole text...
        # ... + All tokens
        ([], TOKENS, "ⓟn°1Ⓟ x ⓓ16 janvier 1611Ⓓ\nⓝMichelⓃ"),
        # ... + 1st and 2nd tokens
        ([], filter_tokens(["P", "D"]), "ⓟn°1Ⓟ x ⓓ16 janvier 1611Ⓓ\nMichel"),
        # ... + 1st and 3rd tokens
        ([], filter_tokens(["P", "N"]), "ⓟn°1Ⓟ x 16 janvier 1611\nⓝMichelⓃ"),
        # ... + 2nd and 3rd tokens
        ([], filter_tokens(["D", "N"]), "n°1 x ⓓ16 janvier 1611Ⓓ\nⓝMichelⓃ"),
        # Only entities...
        # ... + All tokens
        (["\n", " "], TOKENS, "ⓟn°1Ⓟ ⓓ16 janvier 1611Ⓓ\nⓝMichelⓃ"),
        # ... + 1st and 2nd tokens
        (["\n", " "], filter_tokens(["P", "D"]), "ⓟn°1Ⓟ ⓓ16 janvier 1611Ⓓ"),
        # ... + 1st and 3rd tokens
        (["\n", " "], filter_tokens(["P", "N"]), "ⓟn°1Ⓟ\nⓝMichelⓃ"),
        # ... + 2nd and 3rd tokens
        (["\n", " "], filter_tokens(["D", "N"]), "ⓓ16 janvier 1611Ⓓ\nⓝMichelⓃ"),
    ),
)
@pytest.mark.parametrize("text_before", ("", "text before "))
@pytest.mark.parametrize("text_after", ("", " text after"))
def test_reconstruct_text(entity_separators, tokens, expected, text_before, text_after):
    arkindex_extractor = ArkindexExtractor(entity_separators=entity_separators)
    arkindex_extractor.tokens = tokens
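    # With no entity separators, the whole text (including the surrounding context)
    # is kept; otherwise only the tagged entities and their separators remain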
    assert arkindex_extractor.reconstruct_text(
        text_before + "n°1 x 16 janvier 1611\nMichel" + text_after,
        [
            Entity(
                offset=0 + len(text_before),
                length=3,
                type="P",
                value="n°1",
            ),
            Entity(
                offset=6 + len(text_before),
                length=15,
                type="D",
                value="16 janvier 1611",
            ),
            Entity(
                offset=22 + len(text_before),
                length=6,
                type="N",
                value="Michel",
            ),
        ],
    ) == (
        (text_before if not entity_separators else "")
        + expected
        + (text_after if not entity_separators else "")
    )


@pytest.mark.parametrize(
    "text,trimmed",
    (
        ("no_spaces", "no_spaces"),
        (" beginning", "beginning"),
        ("ending ", "ending"),
        (" both ", "both"),
        ("    consecutive", "consecutive"),
        ("\ttab", "tab"),
        ("\t tab", "tab"),
        (" \ttab", "tab"),
        ("no|space", "no|space"),
    ),
)
def test_normalize_spaces(text, trimmed):
    assert normalize_spaces(text) == trimmed


@pytest.mark.parametrize(
    "text,trimmed",
    (
        ("no_linebreaks", "no_linebreaks"),
        ("\nbeginning", "beginning"),
        ("ending\n", "ending"),
        ("\nboth\n", "both"),
        ("\n\n\nconsecutive", "consecutive"),
        ("\rcarriage_return", "carriage_return"),
        ("\r\ncarriage_return+linebreak", "carriage_return+linebreak"),
        ("\n\r\r\n\ncarriage_return+linebreak", "carriage_return+linebreak"),
        ("no|linebreaks", "no|linebreaks"),
    ),
)
def test_normalize_linebreaks(text, trimmed):
    assert normalize_linebreaks(text) == trimmed


@pytest.mark.parametrize(
    "entity_separators",
    (
        # Whole text
        [],
        # Only entities
        [" ", "\n"],
    ),
)
@pytest.mark.parametrize("text_before", ("", "text before "))
@pytest.mark.parametrize("text_after", ("", " text after"))
def test_reconstruct_text_joined_entities(entity_separators, text_before, text_after):
    arkindex_extractor = ArkindexExtractor(entity_separators=entity_separators)
    arkindex_extractor.tokens = TOKENS
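    # Adjacent entities with no text between them stay joined in the reconstructed text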
    assert arkindex_extractor.reconstruct_text(
        text_before + "LouisXIV" + text_after,
        [
            Entity(
                offset=0 + len(text_before),
                length=5,
                type="P",
                value="Louis",
            ),
            Entity(
                offset=5 + len(text_before),
                length=3,
                type="I",
                value="XIV",
            ),
        ],
    ) == (
        (text_before if not entity_separators else "")
        + "ⓟLouisⓅⓘXIVⒾ"
        + (text_after if not entity_separators else "")
    )


@pytest.mark.parametrize("text_before", ("", "text before "))
@pytest.mark.parametrize("text_after", ("", " text after"))
def test_reconstruct_text_several_separators(text_before, text_after):
    arkindex_extractor = ArkindexExtractor(entity_separators=["\n", " "])
    arkindex_extractor.tokens = TOKENS
    # Both " " and "\n" appear between the two entities; only "\n" is kept
    assert (
        arkindex_extractor.reconstruct_text(
            text_before + "King\nLouis XIV" + text_after,
            [
                Entity(
                    offset=0 + len(text_before),
                    length=4,
                    type="D",
                    value="King",
                ),
                Entity(
                    offset=11 + len(text_before),
                    length=3,
                    type="I",
                    value="XIV",
                ),
            ],
        )
        == "ⓓKingⒹ\nⓘXIVⒾ"
    )


@pytest.mark.parametrize("joined", (True, False))
@pytest.mark.parametrize("text_before", ("", "text before "))
@pytest.mark.parametrize("text_after", ("", " text after"))
def test_reconstruct_text_only_start_token(joined, text_before, text_after):
    separator = " " if not joined else ""

    arkindex_extractor = ArkindexExtractor(entity_separators=[" ", "\n"])
    arkindex_extractor.tokens = {
        "P": EntityType(start="ⓟ"),
        "I": EntityType(start="ⓘ"),
    }
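    # Entity types without an end token only get their start token inserted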
    assert (
        arkindex_extractor.reconstruct_text(
            text_before + "Louis" + separator + "XIV" + text_after,
            [
                Entity(
                    offset=0 + len(text_before),
                    length=5,
                    type="P",
                    value="Louis",
                ),
                Entity(
                    offset=5 + len(separator) + len(text_before),
                    length=3,
                    type="I",
                    value="XIV",
                ),
            ],
        )
        == "ⓟLouis" + separator + "ⓘXIV"
    )


def test_process_element_unknown_token_in_text_error(mock_database, tmp_path):
    output = tmp_path / "extraction"
    arkindex_extractor = ArkindexExtractor(output=output)

    # Create an element whose transcription contains the unknown token (⁇)
    element = Element.create(
        id="element_id",
        name="1",
        type="page",
        polygon="[]",
        created=0.0,
        updated=0.0,
    )
    Transcription.create(
        id="transcription_id",
        text="Is this text valid⁇",
        element=element,
    )

    with pytest.raises(
        UnknownTokenInText,
        match=re.escape(
            "Unknown token found in the transcription text of element (element_id)"
        ),
    ):
        arkindex_extractor.process_element(element, "val")


@pytest.mark.parametrize(
    "load_entities,keep_spaces,transcription_entities_worker_version,expected_subword_language_corpus,subword_vocab_size",
    (
        (
            True,
            True,
            "worker_version_id",
            """▁ ⓢ c a i l l e t ▁ ⓕ m a u r i c e ▁ ⓑ 28. 9.0 6
▁ ⓢ re b ou l ▁ ⓕ j e a n ▁ ⓑ 30. 9.0 2
▁ ⓢ b a re y re ▁ ⓕ j e a n ▁ ⓑ 28. 3 . 1 1
▁ ⓢ r ou s s y ▁ ⓕ j e a n ▁ ⓑ 4 . 1 1 . 1 4
▁ ⓢ m a r i n ▁ ⓕ m a r c e l ▁ ⓑ 1 0 . 8 . 0 6
▁ ⓢ a m i c a l ▁ ⓕ e l o i ▁ ⓑ 1 1 . 1 0 . 0 4
▁ ⓢ b i r o s ▁ ⓕ m a e l ▁ ⓑ 30. 1 0 . 1 0""",
            40,
        ),
        (
            True,
            False,
            "worker_version_id",
            """▁ ⓢ c a i l l e t ▁ ⓕ m a u r i c e ▁ ⓑ 28. 9.0 6
▁ ⓢ re b ou l ▁ ⓕ j e a n ▁ ⓑ 30. 9.0 2
▁ ⓢ b a re y re ▁ ⓕ j e a n ▁ ⓑ 28. 3 . 1 1
▁ ⓢ r ou s s y ▁ ⓕ j e a n ▁ ⓑ 4 . 1 1 . 1 4
▁ ⓢ m a r i n ▁ ⓕ m a r c e l ▁ ⓑ 1 0 . 8 . 0 6
▁ ⓢ a m i c a l ▁ ⓕ e l o i ▁ ⓑ 1 1 . 1 0 . 0 4
▁ ⓢ b i r o s ▁ ⓕ m a e l ▁ ⓑ 30. 1 0 . 1 0""",
            40,
        ),
        (
            False,
            True,
            "worker_version_id",
            """▁ ca i l l e t ▁ ma u r i ce ▁ 28. 9.0 6
▁ re b o u l ▁ j e a n ▁ 30. 9.0 2
▁ b a re y re ▁ j e a n ▁ 28. 3 . 1 1
▁ r o u s s y ▁ j e a n ▁ 4 . 11.1 4
▁ ma r i n ▁ ma r ce l ▁ 10. 8 . 0 6
▁ a m i ca l ▁ el o i ▁ 11.1 0 . 0 4
▁ b i r o s ▁ ma el ▁ 30. 10. 1 0""",
            40,
        ),
        (
            False,
            False,
            "worker_version_id",
            """▁ ca i l l e t ▁ ma u r i ce ▁ 28. 9.0 6
▁ re b o u l ▁ j e a n ▁ 30. 9.0 2
▁ b a re y re ▁ j e a n ▁ 28. 3 . 1 1
▁ r o u s s y ▁ j e a n ▁ 4 . 11.1 4
▁ ma r i n ▁ ma r ce l ▁ 10. 8 . 0 6
▁ a m i ca l ▁ el o i ▁ 11.1 0 . 0 4
▁ b i r o s ▁ ma el ▁ 30. 10. 1 0""",
            40,
        ),
        (
            True,
            True,
            False,
            """▁ ⓢ C a i l l e t ▁ ⓕ M a u r i c e ▁ ⓑ 2 8 . 9 . 0 6
▁ ⓢ R e b o u l ▁ ⓕ J e a n ▁ ⓑ 3 0 . 9 . 0 2
▁ ⓢ B a r e y r e ▁ ⓕ J e a n ▁ ⓑ 2 8 . 3 . 1 1
▁ ⓢ R o u s s y ▁ ⓕ J e a n ▁ ⓑ 4 . 1 1 . 1 4
▁ ⓢ M a r i n ▁ ⓕ M a r c e l ▁ ⓑ 1 0 . 8 . 0 6
▁ ⓢ A m i c a l ▁ ⓕ E l o i ▁ ⓑ 1 1 . 1 0 . 0 4
▁ ⓢ B i r o s ▁ ⓕ M a e l ▁ ⓑ 3 0 . 1 0 . 1 0""",
            40,
        ),
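        # Same parameters as the previous case, but with a larger subword vocabulary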
        (
            True,
            True,
            False,
            """▁ ⓢ C a i l l e t ▁ ⓕ M a u ri ce ▁ ⓑ 28. 9.0 6
▁ ⓢ R e b ou l ▁ ⓕ J e a n ▁ ⓑ 30. 9.0 2
▁ ⓢ B a re y re ▁ ⓕ J e a n ▁ ⓑ 28. 3 . 1 1
▁ ⓢ R ou s s y ▁ ⓕ J e a n ▁ ⓑ 4 . 11.1 4
▁ ⓢ Mar i n ▁ ⓕ Mar ce l ▁ ⓑ 10. 8 . 0 6
▁ ⓢ A m ic a l ▁ ⓕ E l o i ▁ ⓑ 11.1 0 . 0 4
▁ ⓢ B i r o s ▁ ⓕ M a e l ▁ ⓑ 30. 10. 10""",
            55,
        ),
        (
            True,
            False,
            False,
            """▁ ⓢ C a i l l e t ▁ ⓕ M a u r i c e ▁ ⓑ 2 8 . 9 . 0 6
▁ ⓢ R e b o u l ▁ ⓕ J e a n ▁ ⓑ 3 0 . 9 . 0 2
▁ ⓢ B a r e y r e ▁ ⓕ J e a n ▁ ⓑ 2 8 . 3 . 1 1
▁ ⓢ R o u s s y ▁ ⓕ J e a n ▁ ⓑ 4 . 1 1 . 1 4
▁ ⓢ M a r i n ▁ ⓕ M a r c e l ▁ ⓑ 1 0 . 8 . 0 6
▁ ⓢ A m i c a l ▁ ⓕ E l o i ▁ ⓑ 1 1 . 1 0 . 0 4
▁ ⓢ B i r o s ▁ ⓕ M a e l ▁ ⓑ 3 0 . 1 0 . 1 0""",
            40,
        ),
        (
            False,
            True,
            False,
            """▁ C a i l l e t ▁ Ma u r i c e ▁ 28. 9.0 6
▁ R e b o u l ▁ J e a n ▁ 30. 9.0 2
▁ B a r e y r e ▁ J e a n ▁ 28. 3 . 1 1
▁ R o u s s y ▁ J e a n ▁ 4 . 1 1 . 1 4
▁ Ma r i n ▁ Ma r c e l ▁ 1 0 . 8 . 0 6
▁ A m i c a l ▁ E l o i ▁ 1 1 . 1 0 . 0 4
▁ B i r o s ▁ Ma e l ▁ 30. 1 0 . 1 0""",
            40,
        ),
        (
            False,
            False,
            False,
            """▁ C a i l l e t ▁ Ma u r i c e ▁ 28. 9.0 6
▁ R e b o u l ▁ J e a n ▁ 30. 9.0 2
▁ B a r e y r e ▁ J e a n ▁ 28. 3 . 1 1
▁ R o u s s y ▁ J e a n ▁ 4 . 1 1 . 1 4
▁ Ma r i n ▁ Ma r c e l ▁ 1 0 . 8 . 0 6
▁ A m i c a l ▁ E l o i ▁ 1 1 . 1 0 . 0 4
▁ B i r o s ▁ Ma e l ▁ 30. 1 0 . 1 0""",
            40,
        ),
    ),
)
@patch("dan.datasets.extract.arkindex.download_image")
def test_extract(
    mock_download_image,
    load_entities,
    keep_spaces,
    transcription_entities_worker_version,
    mock_database,
    expected_subword_language_corpus,
    subword_vocab_size,
    tmp_path,
):
    output = tmp_path / "extraction"
    output.mkdir(parents=True, exist_ok=True)
    (output / "language_model").mkdir(parents=True, exist_ok=True)
    tokens_path = EXTRACTION_DATA_PATH / "tokens.yml"
    tokens = [
        token
        for entity_type in parse_tokens(tokens_path).values()
        for token in [entity_type.start, entity_type.end]
        if token
    ]

    def mock_build_image_url(image_url, polygon, *args, **kwargs):
        # During tests, the image URL is its local path
        return polygon_to_bbox(json.loads(str(polygon))), image_url

    extractor = ArkindexExtractor(
        folders=["train", "val", "test"],
        element_type=["text_line"],
        parent_element_type="double_page",
        output=output,
        entity_separators=[" "] if load_entities else None,
        tokens=tokens_path if load_entities else None,
        transcription_worker_version=transcription_entities_worker_version,
        entity_worker_version=transcription_entities_worker_version
        if load_entities
        else None,
        keep_spaces=keep_spaces,
        image_extension=".jpg",
        subword_vocab_size=subword_vocab_size,
    )
    # Mock build_iiif_url to return the bounding box and the local path to the image
    extractor.build_iiif_url = mock_build_image_url
    # Mock download_image so that it simply opens the local image with Pillow
    mock_download_image.side_effect = Image.open
    extractor.run()

    # Check files
    IMAGE_DIR = output / "images"
    TEST_DIR = IMAGE_DIR / "test"
    TRAIN_DIR = IMAGE_DIR / "train"
    VAL_DIR = IMAGE_DIR / "val"

    expected_paths = [
        output / "charset.pkl",
        # Images of test folder
        TEST_DIR / "test-page_1-line_1.jpg",
        TEST_DIR / "test-page_1-line_2.jpg",
        TEST_DIR / "test-page_1-line_3.jpg",
        TEST_DIR / "test-page_2-line_1.jpg",
        TEST_DIR / "test-page_2-line_2.jpg",
        TEST_DIR / "test-page_2-line_3.jpg",
        # Images of train folder
        TRAIN_DIR / "train-page_1-line_1.jpg",
        TRAIN_DIR / "train-page_1-line_2.jpg",
        TRAIN_DIR / "train-page_1-line_3.jpg",
        TRAIN_DIR / "train-page_1-line_4.jpg",
        TRAIN_DIR / "train-page_2-line_1.jpg",
        TRAIN_DIR / "train-page_2-line_2.jpg",
        TRAIN_DIR / "train-page_2-line_3.jpg",
        # Images of val folder
        VAL_DIR / "val-page_1-line_1.jpg",
        VAL_DIR / "val-page_1-line_2.jpg",
        VAL_DIR / "val-page_1-line_3.jpg",
        output / "labels.json",
        # Language resources
        output / "language_model" / "corpus_characters.txt",
        output / "language_model" / "corpus_subwords.txt",
        output / "language_model" / "corpus_words.txt",
        output / "language_model" / "lexicon_characters.txt",
        output / "language_model" / "lexicon_subwords.txt",
        output / "language_model" / "lexicon_words.txt",
        output / "language_model" / "subword_tokenizer.model",
        output / "language_model" / "subword_tokenizer.vocab",
        output / "language_model" / "tokens.txt",
    ]
    assert sorted(filter(methodcaller("is_file"), output.rglob("*"))) == expected_paths

    # Check "labels.json"
    expected_labels = {
        "test": {
            str(TEST_DIR / "test-page_1-line_1.jpg"): "ⓢCou⁇e⁇  ⓕBouis  ⓑ⁇.12.14",
            str(TEST_DIR / "test-page_1-line_2.jpg"): "ⓢ⁇outrain  ⓕA⁇ol⁇⁇e  ⓑ9.4.13",
            str(TEST_DIR / "test-page_1-line_3.jpg"): "ⓢ⁇abale  ⓕ⁇ran⁇ais  ⓑ26.3.11",
            str(TEST_DIR / "test-page_2-line_1.jpg"): "ⓢ⁇urosoy  ⓕBouis  ⓑ22⁇4⁇18",
            str(TEST_DIR / "test-page_2-line_2.jpg"): "ⓢColaiani  ⓕAn⁇els  ⓑ28.11.1⁇",
            str(TEST_DIR / "test-page_2-line_3.jpg"): "ⓢRenouar⁇  ⓕMaurice  ⓑ2⁇.⁇.04",
        },
        "train": {
            str(TRAIN_DIR / "train-page_1-line_1.jpg"): "ⓢCaillet  ⓕMaurice  ⓑ28.9.06",
            str(TRAIN_DIR / "train-page_1-line_2.jpg"): "ⓢReboul  ⓕJean  ⓑ30.9.02",
            str(TRAIN_DIR / "train-page_1-line_3.jpg"): "ⓢBareyre  ⓕJean  ⓑ28.3.11",
            str(TRAIN_DIR / "train-page_1-line_4.jpg"): "ⓢRoussy  ⓕJean  ⓑ4.11.14",
            str(TRAIN_DIR / "train-page_2-line_1.jpg"): "ⓢMarin  ⓕMarcel  ⓑ10.8.06",
            str(TRAIN_DIR / "train-page_2-line_2.jpg"): "ⓢAmical  ⓕEloi  ⓑ11.10.04",
            str(TRAIN_DIR / "train-page_2-line_3.jpg"): "ⓢBiros  ⓕMael  ⓑ30.10.10",
        },
        "val": {
            str(VAL_DIR / "val-page_1-line_1.jpg"): "ⓢMonar⁇  ⓕBouis  ⓑ29⁇⁇⁇04",
            str(VAL_DIR / "val-page_1-line_2.jpg"): "ⓢAstier  ⓕArt⁇ur  ⓑ11⁇2⁇13",
            str(VAL_DIR / "val-page_1-line_3.jpg"): "ⓢ⁇e ⁇lie⁇er  ⓕJules  ⓑ21⁇11⁇11",
        },
    }

    # Transcriptions from a worker version are in lowercase
    if transcription_entities_worker_version:
        for split in expected_labels:
            for path in expected_labels[split]:
                expected_labels[split][path] = expected_labels[split][path].lower()

    # If we do not load entities, remove tokens
    if not load_entities:
        token_translations = {ord(token): None for token in tokens}
        for split in expected_labels:
            for path in expected_labels[split]:
                expected_labels[split][path] = expected_labels[split][path].translate(
                    token_translations
                )

    # Replace double spaces with a single space
    if not keep_spaces:
        for split in expected_labels:
            for path in expected_labels[split]:
                expected_labels[split][path] = TWO_SPACES_REGEX.sub(
                    " ", expected_labels[split][path]
                )

    assert json.loads((output / "labels.json").read_text()) == expected_labels

    # Check "charset.pkl"
    expected_charset = set()
    for label in expected_labels["train"].values():
        expected_charset.update(set(label))

    if load_entities:
        expected_charset.update(tokens)
    expected_charset.add("⁇")
    assert set(pickle.loads((output / "charset.pkl").read_bytes())) == expected_charset

    # Check the language model corpora
    expected_char_language_corpus = """ⓢ C a i l l e t ▁ ▁ ⓕ M a u r i c e ▁ ▁ ⓑ 2 8 . 9 . 0 6
ⓢ R e b o u l ▁ ▁ ⓕ J e a n ▁ ▁ ⓑ 3 0 . 9 . 0 2
ⓢ B a r e y r e ▁ ▁ ⓕ J e a n ▁ ▁ ⓑ 2 8 . 3 . 1 1
ⓢ R o u s s y ▁ ▁ ⓕ J e a n ▁ ▁ ⓑ 4 . 1 1 . 1 4
ⓢ M a r i n ▁ ▁ ⓕ M a r c e l ▁ ▁ ⓑ 1 0 . 8 . 0 6
ⓢ A m i c a l ▁ ▁ ⓕ E l o i ▁ ▁ ⓑ 1 1 . 1 0 . 0 4
ⓢ B i r o s ▁ ▁ ⓕ M a e l ▁ ▁ ⓑ 3 0 . 1 0 . 1 0"""

    expected_word_language_corpus = """ⓢ Caillet ▁ ⓕ Maurice ▁ ⓑ 28 ▁ . ▁ 9 ▁ . ▁ 06
ⓢ Reboul ▁ ⓕ Jean ▁ ⓑ 30 ▁ . ▁ 9 ▁ . ▁ 02
ⓢ Bareyre ▁ ⓕ Jean ▁ ⓑ 28 ▁ . ▁ 3 ▁ . ▁ 11
ⓢ Roussy ▁ ⓕ Jean ▁ ⓑ 4 ▁ . ▁ 11 ▁ . ▁ 14
ⓢ Marin ▁ ⓕ Marcel ▁ ⓑ 10 ▁ . ▁ 8 ▁ . ▁ 06
ⓢ Amical ▁ ⓕ Eloi ▁ ⓑ 11 ▁ . ▁ 10 ▁ . ▁ 04
ⓢ Biros ▁ ⓕ Mael ▁ ⓑ 30 ▁ . ▁ 10 ▁ . ▁ 10"""

    # Transcriptions from a worker version are in lowercase
    if transcription_entities_worker_version:
        expected_char_language_corpus = expected_char_language_corpus.lower()
        expected_word_language_corpus = expected_word_language_corpus.lower()
        expected_subword_language_corpus = expected_subword_language_corpus.lower()

    # If we do not load entities, remove tokens
    if not load_entities:
        expected_char_language_corpus = ENTITY_TOKEN_SPACE.sub(
            "", expected_char_language_corpus
        )
        expected_word_language_corpus = ENTITY_TOKEN_SPACE.sub(
            "", expected_word_language_corpus
        )
        expected_subword_language_corpus = ENTITY_TOKEN_SPACE.sub(
            "", expected_subword_language_corpus
        )
    # Collapse double space markers into a single one
    if not keep_spaces:
        expected_char_language_corpus = TWO_SPACES_LM_REGEX.sub(
            "▁", expected_char_language_corpus
        )
        expected_word_language_corpus = TWO_SPACES_LM_REGEX.sub(
            "▁", expected_word_language_corpus
        )
        expected_subword_language_corpus = TWO_SPACES_LM_REGEX.sub(
            "▁", expected_subword_language_corpus
        )

    assert (
        output / "language_model" / "corpus_characters.txt"
    ).read_text() == expected_char_language_corpus

    assert (
        output / "language_model" / "corpus_words.txt"
    ).read_text() == expected_word_language_corpus

    assert (
        output / "language_model" / "corpus_subwords.txt"
    ).read_text() == expected_subword_language_corpus

    # Check "language_model/tokens.txt"
    expected_language_tokens = [
        "▁" if t.isspace() else t for t in sorted(list(expected_charset))
    ]
    expected_language_tokens.append("◌")
    assert (output / "language_model" / "tokens.txt").read_text() == "\n".join(
        expected_language_tokens
    )

    # Check the language model lexicons
    expected_language_char_lexicon = [f"{t} {t}" for t in expected_language_tokens]
    assert (
        output / "language_model" / "lexicon_characters.txt"
    ).read_text() == "\n".join(expected_language_char_lexicon)

    word_vocab = set(expected_word_language_corpus.split())
    expected_language_word_lexicon = [
        f"{word} {' '.join(word)}" for word in sorted(word_vocab)
    ]
    assert (output / "language_model" / "lexicon_words.txt").read_text() == "\n".join(
        expected_language_word_lexicon
    )

    subword_vocab = set(expected_subword_language_corpus.split())
    expected_language_subword_lexicon = [
        f"{subword} {' '.join(subword)}" for subword in sorted(subword_vocab)
    ]
    assert (
        output / "language_model" / "lexicon_subwords.txt"
    ).read_text() == "\n".join(expected_language_subword_lexicon)

    # Check cropped images
    for expected_path in expected_paths:
        if expected_path.suffix != ".jpg":
            continue

        # The cropped image should be identical to the expected fixture image
        assert (
            ImageChops.difference(
                Image.open(
                    EXTRACTION_DATA_PATH / "images" / "text_line" / expected_path.name
                ),
                Image.open(expected_path),
            ).getbbox()
            is None
        )


@patch("dan.datasets.extract.arkindex.ArkindexExtractor.build_iiif_url")
def test_download_image_error(iiif_url, caplog, capsys):
    task = {
        "split": "train",
        "polygon": [],
        "image_url": "deadbeef",
        "destination": Path("/dev/null"),
    }
    # Make download_image fail: the URL is not an HTTP(S) URL
    iiif_url.return_value = BoundingBox(0, 0, 0, 0), task["image_url"]

    extractor = ArkindexExtractor(
        folders=["train", "val", "test"],
        element_type=["text_line"],
        parent_element_type="double_page",
        output=None,
        entity_separators=None,
        tokens=None,
        transcription_worker_version=None,
        entity_worker_version=None,
        keep_spaces=False,
        image_extension=".jpg",
    )

    # Register the task so the extractor tries to download it
    extractor.tasks = [task]

    # Add the corresponding entry to the extractor data
    extractor.data[task["split"]][str(task["destination"])] = "deadbeefdata"

    extractor.download_images()

    # Key should have been removed
    assert task["destination"] not in extractor.data[task["split"]]

    # Check error log
    assert len(caplog.record_tuples) == 1
    _, level, msg = caplog.record_tuples[0]
    assert level == logging.ERROR
    assert msg == "Failed to download 1 image(s)."

    # Check stdout
    captured = capsys.readouterr()
    assert captured.out == "deadbeef: Image URL must be HTTP(S) for element null\n"


def test_download_image_error_try_max(responses, caplog):
    # An image URL using the IIIF "full" size keyword
    url = (
        "https://blabla.com/iiif/2/image_path.jpg/231,699,2789,3659/full/0/default.jpg"
    )
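    # The same URL with the "max" size keyword, which download_image falls back to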
    fixed_url = (
        "https://blabla.com/iiif/2/image_path.jpg/231,699,2789,3659/max/0/default.jpg"
    )

    # Fake an error response for the "full" URL
    responses.add(
        responses.GET,
        url,
        status=400,
    )
    # Correct response with max
    responses.add(
        responses.GET,
        fixed_url,
        status=200,
        body=next((FIXTURES / "prediction" / "images").iterdir()).read_bytes(),
    )

    image = download_image(url)

    assert image
    # We try 3 times with the first URL
    # Then the first try with the new URL is successful
    assert len(responses.calls) == 4
    assert list(map(attrgetter("request.url"), responses.calls)) == [url] * 3 + [
        fixed_url
    ]

    # Check error log
    assert len(caplog.record_tuples) == 2

    # We should only have WARNING levels
    assert set(level for _, level, _ in caplog.record_tuples) == {logging.WARNING}


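# Elements without a transcription should yield an empty string when allow_empty
# is True, and raise NoTranscriptionError otherwise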
@pytest.mark.parametrize("allow_empty", (True, False))
def test_empty_transcription(allow_empty, mock_database):
    extractor = ArkindexExtractor(
        folders=["train", "val", "test"],
        element_type=["text_line"],
        parent_element_type="double_page",
        output=None,
        entity_separators=None,
        tokens=None,
        transcription_worker_version=None,
        entity_worker_version=None,
        keep_spaces=False,
        image_extension=".jpg",
        allow_empty=allow_empty,
    )
    element_no_transcription = Element(id="unknown")
    if allow_empty:
        assert extractor.extract_transcription(element_no_transcription) == ""
    else:
        with pytest.raises(NoTranscriptionError):
            extractor.extract_transcription(element_no_transcription)