Skip to content
Snippets Groups Projects
arguments.py 4.41 KiB
Newer Older
Yoann Schneider's avatar
Yoann Schneider committed
# -*- coding: utf-8 -*-
from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, List, Optional, Union

from line_image_extractor.extractor import Extraction

from atr_data_generator.arguments import BaseArgs

# Literal "manual" marker. NOTE(review): presumably the value that
# `FilterArgs.accepted_worker_version_ids` accepts to keep only manual
# transcriptions (its docstring mentions `manual`) - confirm against callers.
MANUAL = "manual"
# Scaling ratio meaning "no rescaling"; default for every `ScaleArgs` field.
DEFAULT_RESCALE = 1.0


class ListedEnum(str, Enum):
    """String-backed enum base class that can enumerate its own values."""

    @classmethod
    def list(cls):
        """Return the values of all members, in definition order."""
        values = []
        for member in cls:
            values.append(member.value)
        return values

    def __str__(self) -> str:
        """Render a member as its raw value instead of ``ClassName.member``."""
        return self.value


class ExtractionMode(ListedEnum):
    """Extraction mode of the image.

    Every member's value is identical to its name. The `mode` property uses
    that value to fetch the attribute of the same name from
    `line_image_extractor.extractor.Extraction`.

    NOTE(review): what each mode actually does to the image is defined by
    `line_image_extractor`, not by this module - check there before relying
    on a particular mode.

    Attributes:
        boundingRect: resolves to `Extraction.boundingRect`
        min_area_rect: resolves to `Extraction.min_area_rect`
        deskew_min_area_rect: resolves to `Extraction.deskew_min_area_rect`
        skew_min_area_rect: resolves to `Extraction.skew_min_area_rect`
        polygon: resolves to `Extraction.polygon`
        skew_polygon: resolves to `Extraction.skew_polygon`
        deskew_polygon: resolves to `Extraction.deskew_polygon`
    """

    boundingRect: str = "boundingRect"
    min_area_rect: str = "min_area_rect"
    deskew_min_area_rect: str = "deskew_min_area_rect"
    skew_min_area_rect: str = "skew_min_area_rect"
    polygon: str = "polygon"
    skew_polygon: str = "skew_polygon"
    deskew_polygon: str = "deskew_polygon"

    @property
    def mode(self):
        """Return the `Extraction` attribute named after this member's value."""
        attribute_name = self.value
        return getattr(Extraction, attribute_name)


@dataclass
class SelectArgs(BaseArgs):
    """
    Arguments to select elements from Arkindex

    Both fields are optional and default to `None`.
    NOTE(review): presumably `None` means "do not filter on this criterion";
    confirm against the code consuming these arguments.

    Args:
        element_type (str): Filter elements to process by type
        parent_type (str): Filter elements parents to process by type
    """

    element_type: Optional[str] = None
    parent_type: Optional[str] = None


@dataclass
class ScaleArgs(BaseArgs):
    """Scale the polygon if needed.

    Args:
        x (float): Ratio of how much to scale the polygon horizontally (1.0 means no rescaling)
        y_top (float): Ratio of how much to scale the polygon vertically on the top (1.0 means no rescaling)
        y_bottom (float): Ratio of how much to scale the polygon vertically on the bottom (1.0 means no rescaling)
    """

    x: float = DEFAULT_RESCALE
    y_top: float = DEFAULT_RESCALE
    y_bottom: float = DEFAULT_RESCALE

    def __post_init__(self):
        """Replace unspecified (`None`) ratios with `DEFAULT_RESCALE`.

        Ratios that were explicitly provided are kept as-is. Resetting all
        three ratios as soon as one of them differed from the default threw
        away every requested scaling and made `should_resize_polygons`
        always evaluate to False.
        """
        # use 1.0 as default - no resize, if not specified
        if self.x is None:
            self.x = DEFAULT_RESCALE
        if self.y_top is None:
            self.y_top = DEFAULT_RESCALE
        if self.y_bottom is None:
            self.y_bottom = DEFAULT_RESCALE

    @property
    def should_resize_polygons(self):
        """bool: True when at least one ratio differs from `DEFAULT_RESCALE`."""
        return bool(
            self.x != DEFAULT_RESCALE
            or self.y_top != DEFAULT_RESCALE
            or self.y_bottom != DEFAULT_RESCALE
        )
Yoann Schneider's avatar
Yoann Schneider committed


@dataclass
class ImageArgs(BaseArgs):
    """
    Arguments related to image transformation.

    Args:
        scale (ScaleArgs or dict): Polygon scaling ratios, given either as a `ScaleArgs` instance or as a
            mapping of `ScaleArgs` field names to values. Always a `ScaleArgs` after initialisation.
        extraction_mode (str): Mode for extracting the line images, see `ExtractionMode` class for options.
            Raw string values are converted to `ExtractionMode` members at initialisation.
        max_deskew_angle (int): Maximum angle by which deskewing is allowed to rotate the line image.
            If the angle determined by deskew tool is bigger than max then that line won't be deskewed/rotated.
        skew_angle (int): Angle by which the line image will be rotated. Useful for data augmentation -
            creating skewed text lines for a more robust model. Only used with skew_* extraction modes.
        should_rotate (bool): Use text line rotation class to rotate lines if possible
        grayscale (bool): Convert images to grayscale (By default grayscale)
        fixed_height (int): Optional, `None` by default.
            NOTE(review): presumably the height extracted line images are resized to - confirm against callers.
    """

    scale: Union[ScaleArgs, Dict[str, str]] = field(default_factory=dict)
    extraction_mode: ExtractionMode = ExtractionMode.deskew_min_area_rect
    max_deskew_angle: int = 45
    skew_angle: int = 0
    should_rotate: bool = False
    grayscale: bool = True
    fixed_height: Optional[int] = None

    def __post_init__(self):
        """Normalise `scale` to a `ScaleArgs` and `extraction_mode` to an `ExtractionMode`.

        Raises:
            ValueError: if `extraction_mode` is not a valid `ExtractionMode` value.
        """
        # The annotation allows an already-built ScaleArgs; `**` unpacking it
        # would raise TypeError, so only mappings (or nothing) are converted.
        if not isinstance(self.scale, ScaleArgs):
            self.scale = ScaleArgs(**(self.scale or {}))
        # Value lookup returns the member unchanged when one is passed and
        # turns a raw string into its member, so `json()` can rely on `.value`.
        self.extraction_mode = ExtractionMode(self.extraction_mode)

    def json(self):
        """Return the parent's `json()` data with `extraction_mode` and `scale` overridden.

        `extraction_mode` is replaced by its string value and `scale` by
        the result of `ScaleArgs.json()`.
        """
        data = super().json()
        data.update(
            {
                "extraction_mode": self.extraction_mode.value,
                "scale": self.scale.json(),
            }
        )
        return data
Yoann Schneider's avatar
Yoann Schneider committed


@dataclass
class FilterArgs(BaseArgs):
    """
    Arguments related to element filtering.

    Args:
        accepted_worker_version_ids (List of str): List of accepted worker version ids. Filter transcriptions by worker version ids.
            Use `manual` to get only manual transcriptions
        skip_vertical_lines (bool): Skip vertical lines.
    """

    accepted_worker_version_ids: List[str] = field(default_factory=list)
    skip_vertical_lines: bool = False

    def json(self):
        """Return the parent's `json()` data with every worker version id converted to `str`."""
        data = super().json()
        # Ids are stringified explicitly so the output only contains plain strings.
        data["accepted_worker_version_ids"] = [
            str(version_id) for version_id in self.accepted_worker_version_ids
        ]
        return data