# -*- coding: utf-8 -*-
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Union
Yoann Schneider's avatar
Yoann Schneider committed
from uuid import UUID
Yoann Schneider's avatar
Yoann Schneider committed

from line_image_extractor.extractor import Extraction

from atr_data_generator.arguments import BaseArgs
Yoann Schneider's avatar
Yoann Schneider committed
from atr_data_generator.extract.utils import ListedEnum
Yoann Schneider's avatar
Yoann Schneider committed

# Keyword accepted in place of a worker version UUID to request manual
# transcriptions only (see FilterArgs.accepted_worker_version_ids).
MANUAL = "manual"
# Scaling ratio meaning "no rescaling"; default value of every ScaleArgs ratio.
DEFAULT_RESCALE = 1.0


class ExtractionMode(ListedEnum):
    """Supported ways of extracting a line image.

    Every member's value is the name of the matching attribute on
    ``line_image_extractor.extractor.Extraction``; use :attr:`mode` to
    resolve a member to that attribute.

    Attributes:
        boundingRect:
        min_area_rect:
        deskew_min_area_rect:
        skew_min_area_rect:
        polygon:
        skew_polygon:
        deskew_polygon:
    """

    boundingRect: str = "boundingRect"
    min_area_rect: str = "min_area_rect"
    deskew_min_area_rect: str = "deskew_min_area_rect"
    skew_min_area_rect: str = "skew_min_area_rect"
    polygon: str = "polygon"
    skew_polygon: str = "skew_polygon"
    deskew_polygon: str = "deskew_polygon"

    @property
    def mode(self):
        """Return the ``Extraction`` attribute named after this member's value."""
        attribute_name = self.value
        return getattr(Extraction, attribute_name)


@dataclass
class SelectArgs(BaseArgs):
    """
    Arguments to select elements from Arkindex.

    Both fields are optional and default to ``None``.

    Args:
        element_type (str): Filter elements to process by type.
        parent_type (str): Filter the parents of the elements to process by type.
    """

    element_type: Optional[str] = None
    parent_type: Optional[str] = None


@dataclass
class ScaleArgs(BaseArgs):
    """Scale the polygon if needed.

    Args:
        x (float): Ratio of how much to scale the polygon horizontally (1.0 means no rescaling)
        y_top (float): Ratio of how much to scale the polygon vertically on the top (1.0 means no rescaling)
        y_bottom (float): Ratio of how much to scale the polygon vertically on the bottom (1.0 means no rescaling)
    """

    x: float = DEFAULT_RESCALE
    y_top: float = DEFAULT_RESCALE
    y_bottom: float = DEFAULT_RESCALE

    def __post_init__(self):
        """Replace unspecified (``None``) ratios with the no-rescaling default.

        Ratios that were explicitly provided are kept as given, so that
        ``should_resize_polygons`` reflects what the caller asked for.
        """
        # Only fill the gaps: overwriting provided ratios here would silently
        # disable any requested rescaling.
        for ratio_name in ("x", "y_top", "y_bottom"):
            if getattr(self, ratio_name) is None:
                setattr(self, ratio_name, DEFAULT_RESCALE)

    @property
    def should_resize_polygons(self) -> bool:
        """Whether at least one ratio differs from the no-rescaling default."""
        return any(
            ratio != DEFAULT_RESCALE
            for ratio in (self.x, self.y_top, self.y_bottom)
        )
Yoann Schneider's avatar
Yoann Schneider committed


@dataclass
class ImageArgs(BaseArgs):
    """
    Arguments related to image transformation.

    Args:
        scale (ScaleArgs or dict): Polygon scaling ratios, either as a `ScaleArgs` instance or as a
            mapping of `ScaleArgs` keyword arguments. Always a `ScaleArgs` after initialization.
        extraction_mode (ExtractionMode or str): Mode for extracting the line images, see `ExtractionMode`
            class for options. A value given as a string is converted to the matching member.
        max_deskew_angle (int): Maximum angle by which deskewing is allowed to rotate the line image.
            If the angle determined by deskew tool is bigger than max then that line won't be deskewed/rotated.
        skew_angle (int): Angle by which the line image will be rotated. Useful for data augmentation:
            creating skewed text lines for a more robust model. Only used with skew_* extraction modes.
        should_rotate (bool): Use text line rotation class to rotate lines if possible
        grayscale (bool): Convert images to grayscale (By default grayscale)
        fixed_height (int): Resize images to a fixed height.
    """

    scale: Union[ScaleArgs, Dict[str, str]] = field(default_factory=dict)
    extraction_mode: ExtractionMode = ExtractionMode.deskew_min_area_rect
    max_deskew_angle: int = 45
    skew_angle: int = 0
    should_rotate: bool = False
    grayscale: bool = True
    fixed_height: Optional[int] = None

    def __post_init__(self):
        """Normalize `scale` to a `ScaleArgs` and `extraction_mode` to an `ExtractionMode`.

        Raises:
            ValueError: If `extraction_mode` is not a valid `ExtractionMode` value.
        """
        # The annotation allows passing a ready-made ScaleArgs; unpacking it
        # with ** would raise a TypeError, so only build one from a mapping.
        if not isinstance(self.scale, ScaleArgs):
            self.scale = ScaleArgs(**(self.scale or {}))

        # json() and consumers rely on enum members; accept the plain string
        # value as well (e.g. when the arguments come from a parsed config).
        if not isinstance(self.extraction_mode, ExtractionMode):
            self.extraction_mode = ExtractionMode(self.extraction_mode)

    def json(self):
        """Return the arguments as a dict, with the enum and nested scale made serializable."""
        data = super().json()
        data.update(
            {
                "extraction_mode": self.extraction_mode.value,
                "scale": self.scale.json(),
            }
        )
        return data
Yoann Schneider's avatar
Yoann Schneider committed


@dataclass
class FilterArgs(BaseArgs):
    """
    Arguments related to element filtering.

    Args:
        accepted_worker_version_ids (List of str): List of accepted worker version ids. Filter transcriptions by worker version ids.
            Use `manual` to get only manual transcriptions
        skip_vertical_lines (bool): Skip vertical lines.
    """

    accepted_worker_version_ids: List[str] = field(default_factory=list)
    skip_vertical_lines: bool = False

    def __post_init__(self):
        """Validate the given worker version IDs.

        Raises:
            ValueError: If an entry is neither the `manual` keyword nor a valid UUID.
        """
        for source in self.accepted_worker_version_ids:
            if source == MANUAL:
                continue
            # Explicit check instead of `assert`: assertions are stripped when
            # Python runs with -O, which would skip the validation entirely.
            try:
                UUID(str(source))
            except ValueError as err:
                raise ValueError(
                    f"Invalid worker version ID {source!r}: "
                    f"expected a UUID or {MANUAL!r}"
                ) from err

    def json(self):
        """Return the arguments as a dict, with worker version IDs rendered as strings."""
        data = super().json()
        data.update(
            {
                "accepted_worker_version_ids": [
                    str(version_id)
                    for version_id in self.accepted_worker_version_ids
                ]
            }
        )
        return data