# -*- coding: utf-8 -*-
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Union
from uuid import UUID

from line_image_extractor.extractor import Extraction

from atr_data_generator.arguments import BaseArgs
from atr_data_generator.extract.utils import ListedEnum

# Sentinel accepted in `FilterArgs.accepted_worker_version_ids` in place of a
# worker version UUID; it selects manual transcriptions and is skipped by the
# UUID validation.
MANUAL = "manual"
# Scale ratio meaning "no rescaling"; used as the default for every `ScaleArgs` field.
DEFAULT_RESCALE = 1.0


class ExtractionMode(ListedEnum):
    """Available modes for extracting a line image.

    Each member's value is the name of an attribute looked up on
    `line_image_extractor.extractor.Extraction` (see the `mode` property);
    refer to that library for the exact behaviour of each mode.

    Attributes:
        boundingRect: Resolves to `Extraction.boundingRect`.
        min_area_rect: Resolves to `Extraction.min_area_rect`.
        deskew_min_area_rect: Resolves to `Extraction.deskew_min_area_rect`.
        skew_min_area_rect: Resolves to `Extraction.skew_min_area_rect`.
        polygon: Resolves to `Extraction.polygon`.
        skew_polygon: Resolves to `Extraction.skew_polygon`.
        deskew_polygon: Resolves to `Extraction.deskew_polygon`.
    """

    boundingRect: str = "boundingRect"
    min_area_rect: str = "min_area_rect"
    deskew_min_area_rect: str = "deskew_min_area_rect"
    skew_min_area_rect: str = "skew_min_area_rect"
    polygon: str = "polygon"
    skew_polygon: str = "skew_polygon"
    deskew_polygon: str = "deskew_polygon"

    @property
    def mode(self):
        """Return the `Extraction` attribute named after this member's value."""
        attribute_name = self.value
        return getattr(Extraction, attribute_name)


@dataclass
class SelectArgs(BaseArgs):
    """
    Arguments to select elements from Arkindex

    Args:
        dataset (str): Filter dataset to process. Must be a valid UUID string.
        element_type (str): Filter elements to process by type

    Raises:
        ValueError: If `dataset` is not a well-formed UUID string.
    """

    dataset: str
    element_type: Optional[str] = None

    def __post_init__(self):
        # Validate the dataset ID by parsing it. This is deliberately not wrapped
        # in an `assert`: assertions are stripped when Python runs with -O, which
        # would silently disable the check. `UUID` raises ValueError on bad input.
        UUID(self.dataset)
        # Configuration parser issue: https://gitlab.teklia.com/tools/python-toolbox/-/issues/2
        # The parser may hand over the literal string "None" instead of None.
        if self.element_type == "None":
            self.element_type = None


@dataclass
class ScaleArgs(BaseArgs):
    """Scale the polygon if needed.

    Args:
        x (float): Ratio of how much to scale the polygon horizontally (1.0 means no rescaling)
        y_top (float): Ratio of how much to scale the polygon vertically on the top (1.0 means no rescaling)
        y_bottom (float): Ratio of how much to scale the polygon vertically on the bottom (1.0 means no rescaling)

    Raises:
        ValueError: If a provided ratio cannot be converted to a float.
    """

    x: float = DEFAULT_RESCALE
    y_top: float = DEFAULT_RESCALE
    y_bottom: float = DEFAULT_RESCALE

    def __post_init__(self):
        """Normalize the ratios without discarding user-provided values.

        Unspecified ratios (None, or the literal string "None" that the
        configuration parser may produce) fall back to 1.0, i.e. no resize.
        Any other value is kept and coerced to float, since ratios may arrive
        as strings when this object is built from a parsed configuration dict.
        """
        for name in ("x", "y_top", "y_bottom"):
            ratio = getattr(self, name)
            if ratio is None or ratio == "None":
                # use 1.0 as default - no resize, if not specified
                setattr(self, name, DEFAULT_RESCALE)
            else:
                setattr(self, name, float(ratio))

    @property
    def should_resize_polygons(self):
        """Whether at least one ratio differs from the no-rescaling default."""
        return any(
            ratio != DEFAULT_RESCALE for ratio in (self.x, self.y_top, self.y_bottom)
        )


@dataclass
class ImageArgs(BaseArgs):
    """
    Arguments related to image transformation.

    Args:
        scale (ScaleArgs or dict): Polygon rescaling ratios, given either as a `ScaleArgs` instance or as a
            mapping of `ScaleArgs` field names to values. Always a `ScaleArgs` instance after initialization.
        extraction_mode (ExtractionMode): Mode for extracting the line images, see `ExtractionMode` class for options.
        max_deskew_angle (int): Maximum angle by which deskewing is allowed to rotate the line image.
            If the angle determined by deskew tool is bigger than max then that line won't be deskewed/rotated.
        skew_angle (int): Angle by which the line image will be rotated. Useful for data augmentation,
            creating skewed text lines for a more robust model. Only used with skew_* extraction modes.
        should_rotate (bool): Use text line rotation class to rotate lines if possible
        grayscale (bool): Convert images to grayscale (By default grayscale)
        fixed_height (int): Resize images to a fixed height.
    """

    scale: Union[ScaleArgs, Dict[str, str]] = field(default_factory=dict)
    extraction_mode: ExtractionMode = ExtractionMode.deskew_min_area_rect
    max_deskew_angle: int = 45
    skew_angle: int = 0
    should_rotate: bool = False
    grayscale: bool = True
    fixed_height: Optional[int] = None

    def __post_init__(self):
        # The field accepts either form; only a mapping needs to be converted.
        # Unpacking an existing ScaleArgs with ** would assume it is a mapping.
        if not isinstance(self.scale, ScaleArgs):
            self.scale = ScaleArgs(**self.scale)

    def json(self):
        """Return the parent `json()` payload with non-primitive fields serialized.

        `extraction_mode` is replaced by the enum member's value and `scale`
        by the output of its own `json()` method.
        """
        data = super().json()
        data.update(
            {
                "extraction_mode": self.extraction_mode.value,
                "scale": self.scale.json(),
            }
        )
        return data


@dataclass
class FilterArgs(BaseArgs):
    """
    Arguments related to element filtering.

    Args:
        accepted_worker_version_ids (List of str): List of accepted worker version ids. Filter transcriptions by worker version ids.
            Use `manual` to get only manual transcriptions
        skip_vertical_lines (bool): Skip vertical lines.

    Raises:
        ValueError: If an accepted worker version id is neither `manual` nor a well-formed UUID string.
    """

    accepted_worker_version_ids: List[str] = field(default_factory=list)
    skip_vertical_lines: bool = False

    def __post_init__(self):
        # Validate given worker version IDs.
        # The UUID is parsed outside of any `assert` so that the check still
        # runs when Python is started with -O (which strips assertions).
        for source in self.accepted_worker_version_ids:
            if source != MANUAL:
                UUID(source)

    def json(self):
        """Return the parent `json()` payload with worker version ids as a list of strings."""
        data = super().json()
        data["accepted_worker_version_ids"] = [
            str(source) for source in self.accepted_worker_version_ids
        ]
        return data