# -*- coding: utf-8 -*-
import os
import pickle
import random

import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

from dan.manager.dataset import OCRDataset
from dan.transforms import (
    get_augmentation_transforms,
    get_normalization_transforms,
    get_preprocessing_transforms,
)
from dan.utils import pad_images, pad_sequences_1D


class OCRDatasetManager:
    def __init__(self, params, device: str):
        self.params = params

        # Whether to use pinned (page-locked) host memory for faster
        # host-to-GPU copies, see https://pytorch.org/docs/stable/generated/torch.Tensor.pin_memory.html
        self.pin_memory = device != "cpu"

        self.train_dataset = None
        self.valid_datasets = dict()
        self.test_datasets = dict()

        self.train_loader = None
        self.valid_loaders = dict()
        self.test_loaders = dict()

        self.train_sampler = None
        self.valid_samplers = dict()
        self.test_samplers = dict()

        # Fixed-seed generator so that DataLoader shuffling is reproducible
        # across runs; it is passed as ``generator=`` to each DataLoader.
        self.generator = torch.Generator()
        self.generator.manual_seed(0)

        self.load_in_memory = self.params["config"].get("load_in_memory", True)
        self.charset = self.get_charset()
        self.tokens = self.get_tokens()
        self.params["config"]["padding_token"] = self.tokens["pad"]

        self.my_collate_function = OCRCollateFunction(self.params["config"])
        self.normalization = get_normalization_transforms()
        self.augmentation = (
            get_augmentation_transforms()
            if self.params["config"]["augmentation"]
            else None
        )
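        # Note that augmentation only applies to the training set: the
        # validation and test datasets below are built with
        # ``augmentation_transforms=None``.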
        self.preprocessing = get_preprocessing_transforms(
            params["config"]["preprocessings"], to_pil_image=True
        )

    def load_datasets(self):
        """
        Load training and validation datasets
        """
        self.train_dataset = OCRDataset(
            set_name="train",
            paths_and_sets=self.get_paths_and_sets(self.params["train"]["datasets"]),
            charset=self.charset,
            tokens=self.tokens,
            preprocessing_transforms=self.preprocessing,
            normalization_transforms=self.normalization,
            augmentation_transforms=self.augmentation,
            load_in_memory=self.load_in_memory,
        )

        for custom_name in self.params["val"].keys():
            self.valid_datasets[custom_name] = OCRDataset(
                set_name="val",
                paths_and_sets=self.get_paths_and_sets(self.params["val"][custom_name]),
                charset=self.charset,
                tokens=self.tokens,
                preprocessing_transforms=self.preprocessing,
                normalization_transforms=self.normalization,
                augmentation_transforms=None,
                load_in_memory=self.load_in_memory,
            )

    def load_ddp_samplers(self):
        """
        Load training and validation data samplers
        """
        if self.params["use_ddp"]:
            self.train_sampler = DistributedSampler(
                self.train_dataset,
                num_replicas=self.params["num_gpu"],
                rank=self.params["ddp_rank"],
                shuffle=True,
            )
            for custom_name in self.valid_datasets.keys():
                self.valid_samplers[custom_name] = DistributedSampler(
                    self.valid_datasets[custom_name],
                    num_replicas=self.params["num_gpu"],
                    rank=self.params["ddp_rank"],
                    shuffle=False,
                )
        else:
            for custom_name in self.valid_datasets.keys():
                self.valid_samplers[custom_name] = None

    def load_dataloaders(self):
        """
        Load training and validation data loaders
        """
        self.train_loader = DataLoader(
            self.train_dataset,
            batch_size=self.params["batch_size"],
            # ``sampler`` and ``shuffle`` are mutually exclusive: only shuffle
            # here when no DistributedSampler handles shuffling itself.
            shuffle=self.train_sampler is None,
            drop_last=False,
            sampler=self.train_sampler,
            num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
            pin_memory=self.pin_memory,
            collate_fn=self.my_collate_function,
            worker_init_fn=self.seed_worker,
            generator=self.generator,
        )

        for key in self.valid_datasets.keys():
            self.valid_loaders[key] = DataLoader(
                self.valid_datasets[key],
                batch_size=1,
                sampler=self.valid_samplers[key],
                shuffle=False,
                num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
                pin_memory=self.pin_memory,
                drop_last=False,
                collate_fn=self.my_collate_function,
                worker_init_fn=self.seed_worker,
                generator=self.generator,
            )

    @staticmethod
    def seed_worker(worker_id):
        """
        Seed a DataLoader worker for reproducibility

        PyTorch calls ``worker_init_fn`` with the worker id, so the argument
        must be accepted even though it is unused here.
        """
        worker_seed = torch.initial_seed() % 2**32
        np.random.seed(worker_seed)
        random.seed(worker_seed)

    def generate_test_loader(self, custom_name, sets_list):
        """
        Load test dataset, data sampler and data loader
        """
        if custom_name in self.test_loaders.keys():
            return
        paths_and_sets = list()
        for set_info in sets_list:
            paths_and_sets.append(
                {"path": self.params["datasets"][set_info[0]], "set_name": set_info[1]}
            )
        self.test_datasets[custom_name] = OCRDataset(
            set_name="test",
            paths_and_sets=paths_and_sets,
            charset=self.charset,
            tokens=self.tokens,
            preprocessing_transforms=self.preprocessing,
            normalization_transforms=self.normalization,
            augmentation_transforms=None,
            load_in_memory=self.load_in_memory,
        )

        if self.params["use_ddp"]:
            self.test_samplers[custom_name] = DistributedSampler(
                self.test_datasets[custom_name],
                num_replicas=self.params["num_gpu"],
                rank=self.params["ddp_rank"],
                shuffle=False,
            )
        else:
            self.test_samplers[custom_name] = None
        self.test_loaders[custom_name] = DataLoader(
            self.test_datasets[custom_name],
            batch_size=1,
            sampler=self.test_samplers[custom_name],
            shuffle=False,
            num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
            pin_memory=self.pin_memory,
            drop_last=False,
            collate_fn=self.my_collate_function,
            worker_init_fn=self.seed_worker,
            generator=self.generator,
        )

    def get_paths_and_sets(self, dataset_names_folds):
        """
        Resolve the dataset path for each (dataset name, fold) pair,
        e.g. [("my_dataset", "train")] becomes
        [{"path": <path to my_dataset>, "set_name": "train"}]
        """
        paths_and_sets = list()
        for dataset_name, fold in dataset_names_folds:
            path = self.params["datasets"][dataset_name]
            paths_and_sets.append({"path": path, "set_name": fold})
        return paths_and_sets

    def get_charset(self):
        """
        Merge the charsets of all datasets into a single sorted charset

        Each dataset directory is expected to contain a ``charset.pkl`` file
        holding an iterable of characters; empty strings are discarded.
        """
        if "charset" in self.params:
            return self.params["charset"]
        datasets = self.params["datasets"]
        charset = set()
        for key in datasets.keys():
            with open(os.path.join(datasets[key], "charset.pkl"), "rb") as f:
                charset = charset.union(set(pickle.load(f)))
        if "" in charset:
            charset.remove("")
        return sorted(list(charset))

    def get_tokens(self):
        """
        Map the special tokens (end, start, pad) to indices placed right
        after the last charset index
        """
        return {
            "end": len(self.charset),
            "start": len(self.charset) + 1,
            "pad": len(self.charset) + 2,
        }


class OCRCollateFunction:
    """
    Collate individual samples into padded mini-batch tensors for the OCR task
    """

    def __init__(self, config):
        self.label_padding_value = config["padding_token"]
        self.config = config

    def __call__(self, batch_data):
        # Pad the token label sequences to the longest label in the batch.
        labels = [data["token_label"] for data in batch_data]
        labels = pad_sequences_1D(labels, padding_value=self.label_padding_value).long()

        # Pad the images to the largest height and width in the batch.
        imgs = [data["img"] for data in batch_data]
        imgs = pad_images(imgs)

        # Gather the remaining per-sample fields under their batch-level keys.
        formatted_batch_data = {
            formatted_key: [data[initial_key] for data in batch_data]
            for formatted_key, initial_key in [
                ("names", "name"),
                ("labels_len", "label_len"),
                ("raw_labels", "label"),
                ("imgs_position", "img_position"),
                ("imgs_reduced_shape", "img_reduced_shape"),
            ]
        }

        formatted_batch_data.update(
            {
                "imgs": imgs,
                "labels": labels,
            }
        )

        return formatted_batch_data
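

# Minimal usage sketch (illustrative only). The ``params`` layout below is an
# assumption reconstructed from the keys this module reads; the dataset name
# and path are hypothetical and must point at a real extracted dataset
# containing a ``charset.pkl``.
if __name__ == "__main__":
    params = {
        "datasets": {"my_dataset": "path/to/my_dataset"},  # hypothetical path
        "train": {"datasets": [("my_dataset", "train")]},
        "val": {"my_dataset-val": [("my_dataset", "val")]},
        "use_ddp": False,
        "ddp_rank": 0,
        "num_gpu": 1,
        "worker_per_gpu": 4,
        "batch_size": 2,
        "config": {
            "load_in_memory": True,
            "augmentation": False,
            "preprocessings": [],
        },
    }
    manager = OCRDatasetManager(params, device="cpu")
    manager.load_datasets()
    manager.load_ddp_samplers()
    manager.load_dataloaders()
    # One collated batch: padded image tensor and padded label tensor.
    batch = next(iter(manager.train_loader))
    print(batch["imgs"].shape, batch["labels"].shape)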