# -*- coding: utf-8 -*-
import json
import uuid
from operator import itemgetter
from typing import List, Optional, Union

import pytest

from arkindex_export import (
    Element,
    ElementPath,
    Entity,
    EntityType,
    Image,
    ImageServer,
    Transcription,
    TranscriptionEntity,
    WorkerRun,
    WorkerVersion,
    database,
)
from dan.ocr.train import update_config
from tests import FIXTURES


@pytest.fixture(scope="session")
def mock_database(tmp_path_factory):
    """Build a temporary SQLite database mimicking an Arkindex export.

    The database is populated from the JSON fixtures under
    ``FIXTURES / "extraction" / "elements"``, starting at the ``root`` element
    and recursing through each element's ``children`` list. Session-scoped so
    the database is built only once per test session.

    Returns the path to the created SQLite file.
    """

    def create_transcription_entity(
        transcription: Transcription,
        worker_version: Union[str, None],
        # NOTE: `type` shadows the builtin, but the name is fixed by the JSON
        # fixtures (rows are passed as **entity), so it cannot be renamed.
        type: str,
        name: str,
        offset: int,
    ) -> None:
        """Create an Entity (and its EntityType if new) and attach it to the
        transcription at the given character offset."""
        # Entity types are shared across entities: reuse if already created.
        entity_type, _ = EntityType.get_or_create(
            name=type, defaults={"id": f"{type}_id"}
        )
        entity = Entity.create(
            id=str(uuid.uuid4()),
            name=name,
            type=entity_type,
            worker_version=worker_version,
        )
        TranscriptionEntity.create(
            entity=entity,
            length=len(name),
            offset=offset,
            transcription=transcription,
            worker_version=worker_version,
        )

    def create_transcriptions(element: Element, entities: List[dict]) -> None:
        """Create two transcriptions (manual and worker-produced) for the
        element, each carrying the given entities. No-op when there are no
        entities."""
        if not entities:
            return

        # Add transcription with entities
        entities = sorted(entities, key=itemgetter("offset"))

        # We will add extra spaces to test the "keep_spaces" parameters of the "extract" command
        # Each entity after the first is shifted right by one extra character
        # per preceding entity, matching the double-space join built below.
        for offset, entity in enumerate(entities[1:], start=1):
            entity["offset"] += offset

        # None = manual transcription; the worker version yields a second,
        # distinguishable transcription for worker-version filtering tests.
        for worker_version in [None, "worker_version_id"]:
            # Use different transcriptions to filter by worker version
            if worker_version == "worker_version_id":
                for entity in entities:
                    entity["name"] = entity["name"].lower()

            transcription = Transcription.create(
                id=element.id + (worker_version or ""),
                # Add extra spaces to test the "keep_spaces" parameters of the "extract" command
                text="  ".join(map(itemgetter("name"), entities)),
                element=element,
                worker_version=worker_version,
            )

            for entity in entities:
                create_transcription_entity(
                    transcription=transcription,
                    worker_version=worker_version,
                    **entity,
                )

    def create_element(id: str, parent: Optional[Element] = None) -> None:
        """Create an Element from its JSON fixture, link it to its parent,
        then recursively create its children and transcriptions."""
        element_path = (FIXTURES / "extraction" / "elements" / id).with_suffix(".json")
        element_json = json.loads(element_path.read_text())

        element_type = element_json["type"]
        image_path = (
            FIXTURES / "extraction" / "images" / element_type / id
        ).with_suffix(".jpg")

        polygon = element_json.get("polygon")
        # Always use page images because polygons are based on the full image
        # Elements without a polygon get no image at all.
        image, _ = (
            Image.get_or_create(
                id=id + "-image",
                defaults={
                    "server": image_server,
                    # Use path to image instead of actual URL since we won't be doing any download
                    "url": image_path,
                    "width": 0,
                    "height": 0,
                },
            )
            if polygon
            else (None, False)
        )

        element = Element.create(
            id=id,
            name=id,
            type=element_type,
            image=image,
            polygon=json.dumps(polygon) if polygon else None,
            created=0.0,
            updated=0.0,
        )

        if parent:
            ElementPath.create(id=str(uuid.uuid4()), parent=parent, child=element)

        create_transcriptions(
            element=element,
            entities=element_json.get("transcription_entities", []),
        )

        # Recursive function to create children
        for child in element_json.get("children", []):
            create_element(id=child, parent=element)

    # Tables to create, ordered so foreign-key targets exist first.
    MODELS = [
        WorkerVersion,
        WorkerRun,
        ImageServer,
        Image,
        Element,
        ElementPath,
        EntityType,
        Entity,
        Transcription,
        TranscriptionEntity,
    ]

    # Initialisation
    tmp_path = tmp_path_factory.mktemp("data")
    database_path = tmp_path / "db.sqlite"
    database.init(
        database_path,
        pragmas={
            # Recommended settings from peewee
            # http://docs.peewee-orm.com/en/latest/peewee/database.html#recommended-settings
            # Do not set journal mode to WAL as it writes in the database
            "cache_size": -1 * 64000,  # 64MB
            "foreign_keys": 1,
            "ignore_check_constraints": 0,
            "synchronous": 0,
        },
    )
    database.connect()

    # Create tables
    database.create_tables(MODELS)

    # Single image server shared by all created images.
    image_server = ImageServer.create(
        url="http://image/server/url",
        display_name="Image server",
    )

    # Worker version referenced by worker-produced transcriptions/entities.
    WorkerVersion.create(
        id="worker_version_id",
        slug="worker_version",
        name="Worker version",
        repository_url="http://repository/url",
        revision="main",
        type="worker",
    )

    # Create folders
    create_element(id="root")

    return database_path


@pytest.fixture
def training_config():
    """Return a small DAN training configuration pointing at the test fixtures.

    The configuration is assembled from three sections (dataset, model,
    training) and passed through ``update_config`` before being returned,
    exactly as the training entry point would do.
    """
    # Where the training data lives and how the splits are named.
    dataset_section = {
        "datasets": {
            "training": str(FIXTURES / "training" / "training_dataset"),
        },
        "train": {
            "name": "training-train",
            "datasets": [
                ("training", "train"),
            ],
        },
        "val": {
            "training-val": [
                ("training", "val"),
            ],
        },
        "test": {
            "training-test": [
                ("training", "test"),
            ],
        },
        "max_char_prediction": 30,  # upper bound on predicted tokens
        "tokens": None,
    }

    # Encoder/decoder architecture hyper-parameters.
    model_section = {
        # Transfer the decision layer from the line HTR model's charset
        "transfered_charset": True,
        # Extra decision-layer tokens ([<eot>, ]), only with transferred charset
        "additional_tokens": 1,
        "encoder": {
            "dropout": 0.5,  # encoder dropout rate
            "nb_layers": 5,  # number of encoder layers
        },
        "h_max": 500,  # max encoder output height (2D positional embedding)
        "w_max": 1000,  # max encoder output width (2D positional embedding)
        "decoder": {
            "l_max": 15000,  # max predicted sequence (1D positional embedding)
            "dec_num_layers": 8,  # transformer decoder layers
            "dec_num_heads": 4,  # attention heads per decoder layer
            "dec_res_dropout": 0.1,  # dropout in decoder layers
            "dec_pred_dropout": 0.1,  # dropout before decision layer
            "dec_att_dropout": 0.1,  # dropout in multi-head attention
            "dec_dim_feedforward": 256,  # decoder feedforward dimension
            "attention_win": 100,  # attention window length
            "enc_dim": 256,  # dimension of extracted features
        },
    }

    # Training loop: data loading, device, metrics, schedules.
    training_section = {
        "data": {
            "batch_size": 2,  # training mini-batch size
            "load_in_memory": True,  # keep all images in CPU memory
            "worker_per_gpu": 4,  # data-loading processes per GPU
            "preprocessings": [
                {
                    "type": "max_resize",
                    "max_width": 2000,
                    "max_height": 2000,
                }
            ],
            "augmentation": True,
        },
        "device": {
            "use_ddp": False,  # DistributedDataParallel off
            "ddp_port": "20027",
            "use_amp": True,  # automatic mixed precision
            "nb_gpu": 0,
            "force_cpu": True,  # CPU-only for test/debug runs
        },
        "metrics": {
            # Metrics computed on the training set
            "train": [
                "loss_ce",
                "cer",
                "wer",
                "wer_no_punct",
            ],
            # Metrics computed on the validation set during training
            "eval": [
                "cer",
                "wer",
                "wer_no_punct",
            ],
        },
        "validation": {
            "eval_on_valid": True,  # evaluate on validation set while training
            "eval_on_valid_interval": 2,  # evaluation interval, in epochs
            "set_name_focus_metric": "training-val",
        },
        "output_folder": "dan_trained_model",  # checkpoints and results folder
        "gradient_clipping": {},
        "max_nb_epochs": 4,  # stop after this many epochs
        "load_epoch": "last",  # "last" to resume training, "best" to evaluate
        "optimizers": {
            "all": {
                "args": {
                    "lr": 0.0001,
                    "amsgrad": False,
                },
            },
        },
        "lr_schedulers": None,  # no learning-rate scheduling
        # Constant 20% teacher-forcing error rate for the whole training
        "label_noise_scheduler": {
            "min_error_rate": 0.2,
            "max_error_rate": 0.2,
            "total_num_steps": 5e4,
        },
        "transfer_learning": None,
    }

    config = {
        "dataset": dataset_section,
        "model": model_section,
        "training": training_section,
    }
    update_config(config)
    return config


@pytest.fixture
def prediction_data_path():
    """Return the directory holding the prediction fixtures."""
    return FIXTURES.joinpath("prediction")