Skip to content
Snippets Groups Projects
arguments.py 7.62 KiB
Newer Older
Yoann Schneider's avatar
Yoann Schneider committed
# -*- coding: utf-8 -*-
import argparse
import getpass
from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, List, Optional, Union

from line_image_extractor.extractor import Extraction

from atr_data_generator.arguments import BaseArgs

# Login name of the user running the process, resolved once at import time.
USER = getpass.getuser()
# Placeholder accepted in `accepted_worker_version_ids` to request manual
# transcriptions; `FilterArgs._validate` swaps it for `None`.
MANUAL = "manual"
# Scale ratio that means "no rescaling".
DEFAULT_RESCALE = 1.0


class ListedEnum(str, Enum):
    """String-valued enum base class that can enumerate its own values."""

    @classmethod
    def list(cls):
        """Return the value of every member, in iteration order."""
        values = []
        for member in cls:
            values.append(member.value)
        return values

    def __str__(self) -> str:
        """Render the member as its raw value rather than `ClassName.member`."""
        return self._value_


class Style(ListedEnum):
    """Font style of the text.

    Members are intentionally left unannotated: an explicit `: str`
    annotation on an enum member is reported as an error by type checkers,
    which infer the member type themselves.

    Attributes:
        handwritten: Handwritten font style.
        typewritten: Printed font style.
        other: Other font style.
    """

    handwritten = "handwritten"
    typewritten = "typewritten"
    other = "other"


class ExtractionMode(ListedEnum):
    """Extraction mode of the image.

    Every member's value is the name of an attribute of
    `line_image_extractor.extractor.Extraction`; use `mode` to resolve it.
    Members are intentionally left unannotated: an explicit `: str`
    annotation on an enum member is reported as an error by type checkers.

    Attributes:
        boundingRect: Resolves to `Extraction.boundingRect`.
        min_area_rect: Resolves to `Extraction.min_area_rect`.
        deskew_min_area_rect: Resolves to `Extraction.deskew_min_area_rect`.
        skew_min_area_rect: Resolves to `Extraction.skew_min_area_rect`.
        polygon: Resolves to `Extraction.polygon`.
        skew_polygon: Resolves to `Extraction.skew_polygon`.
        deskew_polygon: Resolves to `Extraction.deskew_polygon`.
    """

    boundingRect = "boundingRect"
    min_area_rect = "min_area_rect"
    deskew_min_area_rect = "deskew_min_area_rect"
    skew_min_area_rect = "skew_min_area_rect"
    polygon = "polygon"
    skew_polygon = "skew_polygon"
    deskew_polygon = "deskew_polygon"

    @property
    def mode(self):
        """Return the `Extraction` attribute whose name equals this member's value."""
        return getattr(Extraction, self.value)


class TranscriptionType(ListedEnum):
    """Arkindex type of the element to extract.

    Each member's value is identical to its name. Members are intentionally
    left unannotated: an explicit `: str` annotation on an enum member is
    reported as an error by type checkers.

    Attributes:
        word: The `word` element type.
        text_line: The `text_line` element type.
        half_subject_line: The `half_subject_line` element type.
        text_zone: The `text_zone` element type.
        paragraph: The `paragraph` element type.
        act: The `act` element type.
        page: The `page` element type.
        text: The `text` element type.
    """

    word = "word"
    text_line = "text_line"
    half_subject_line = "half_subject_line"
    text_zone = "text_zone"
    paragraph = "paragraph"
    act = "act"
    page = "page"
    text = "text"


@dataclass
class SelectArgs(BaseArgs):
    """
    Arguments to select elements from Arkindex.

    Args:
        corpora (list): List of corpus ids to be used. Defaults to an empty list.
        folders (list): List of folder ids to be used; elements are searched
            recursively in these folders. Defaults to an empty list.
            NOTE(review): the previous wording said elements of `volume_type`
            are searched, but this class has no such field - confirm whether
            `element_type` is meant.
        element_type (str): Filter elements to process by type. Defaults to None.
        parent_type (str): Filter elements parents to process by type. Defaults to None.
    """

    corpora: Optional[List[str]] = field(default_factory=list)
    folders: Optional[List[str]] = field(default_factory=list)
    element_type: Optional[str] = None
    parent_type: Optional[str] = None


@dataclass
class ScaleArgs(BaseArgs):
    """Scale the polygon if needed.

    When at least one ratio is provided, the ratios left unset fall back to
    `DEFAULT_RESCALE` so that only the requested sides are rescaled. When no
    ratio is provided, all three stay `None` and no resizing is requested.

    Args:
        x (float): Ratio of how much to scale the polygon horizontally (1.0 means no rescaling)
        y_top (float): Ratio of how much to scale the polygon vertically on the top (1.0 means no rescaling)
        y_bottom (float): Ratio of how much to scale the polygon vertically on the bottom (1.0 means no rescaling)
    """

    x: Optional[float] = None
    y_top: Optional[float] = None
    y_bottom: Optional[float] = None

    def __post_init__(self):
        """Fill the unset ratios with `DEFAULT_RESCALE` when resizing is requested."""
        super().__post_init__()
        if self.should_resize_polygons:
            # Only the ratios left unset (None or 0) get the neutral default;
            # ratios provided by the caller must be preserved.
            self.x = self.x or DEFAULT_RESCALE
            self.y_top = self.y_top or DEFAULT_RESCALE
            self.y_bottom = self.y_bottom or DEFAULT_RESCALE

    @property
    def should_resize_polygons(self) -> bool:
        """Whether at least one scaling ratio is set to a non-zero value."""
        return bool(self.x or self.y_top or self.y_bottom)


@dataclass
class ImageArgs(BaseArgs):
    """
    Arguments related to image transformation.

    Args:
        scale: Polygon scaling ratios, given either as a `ScaleArgs` instance or as a
            mapping of `ScaleArgs` field names to values. It is always stored as a
            `ScaleArgs` once the instance is initialized.
        extraction_mode: Mode for extracting the line images, one of the `ExtractionMode` members.
        max_deskew_angle: Maximum angle by which deskewing is allowed to rotate the line image.
            If the angle determined by the deskew tool is bigger than this maximum,
            that line won't be deskewed/rotated.
        skew_angle: Angle by which the line image will be rotated. Useful for data augmentation,
            creating skewed text lines for a more robust model. Only used with skew_* extraction modes.
        should_rotate (bool): Use text line rotation class to rotate lines if possible
        grayscale (bool): Convert images to grayscale (enabled by default)
    """

    scale: Union[ScaleArgs, Dict[str, str]] = field(default_factory=dict)
    extraction_mode: ExtractionMode = ExtractionMode.deskew_min_area_rect
    max_deskew_angle: int = 45
    skew_angle: int = 0
    should_rotate: bool = False
    grayscale: bool = True

    def __post_init__(self):
        """Normalize `scale` to a `ScaleArgs` before running the base initialization."""
        # The field also accepts a ready-made ScaleArgs; only a mapping of
        # keyword arguments needs to be converted.
        if not isinstance(self.scale, ScaleArgs):
            self.scale = ScaleArgs(**self.scale)
        super().__post_init__()


@dataclass
class FilterArgs(BaseArgs):
    """
    Arguments related to element filtering.

    Args:
        transcription_type: Which type of elements' transcriptions to use? (page, paragraph, text_line, etc)
        ignored_classes: List of ignored ml_class names. Filter lines by class
        accepted_classes: List of accepted ml_class names. Filter lines by class
        accepted_worker_version_ids: List of accepted worker version ids. Filter transcriptions by worker version ids.
            The order is important - only up to one transcription will be chosen per element (text_line)
            and the worker version order defines the precedence. If there exists a transcription for
            the first worker version then it will be chosen, otherwise will continue on to the next
            worker version.
            Use `--accepted_worker_version_ids manual` to get only manual transcriptions
        skip_vertical_lines: Boolean flag, disabled by default. Presumably skips vertical
            lines - not used in this class, confirm against the consumer.
        style: Filter line images by style class.
        accepted_metadatas: Key-value dictionary where each entry is a mandatory Arkindex metadata name/value. Filter lines by metadata.
        filter_parent_metadatas: Boolean flag, disabled by default. Presumably applies the
            metadata filter to parent elements - not used in this class, confirm against the consumer.
    """

    transcription_type: TranscriptionType = TranscriptionType.text_line
    ignored_classes: List[str] = field(default_factory=list)
    accepted_classes: List[str] = field(default_factory=list)
    accepted_worker_version_ids: List[str] = field(default_factory=list)
    skip_vertical_lines: bool = False
    style: Optional[Style] = None
    accepted_metadatas: dict = field(default_factory=dict)
    filter_parent_metadatas: bool = False

    def _validate(self):
        """Check the class filters for conflicts and normalize manual worker versions.

        Raises:
            argparse.ArgumentTypeError: If `accepted_classes` and `ignored_classes`
                overlap, or if `style` is set while either class list contains a
                `Style` value.
        """
        accepted_classes = self.accepted_classes
        ignored_classes = self.ignored_classes

        # A class cannot be both accepted and ignored.
        if set(accepted_classes) & set(ignored_classes):
            raise argparse.ArgumentTypeError(
                f"--filter.accepted_classes and --filter.ignored_classes values must not overlap ({accepted_classes} - {ignored_classes})"
            )

        # Style names are reserved for the style filter when it is active.
        if self.style and (accepted_classes or ignored_classes):
            if set(Style.list()) & (set(accepted_classes) | set(ignored_classes)):
                raise argparse.ArgumentTypeError(
                    f"--style class values ({Style.list()}) shouldn't be in the accepted_classes list "
                    f"(or ignored_classes list) "
                    "if --filter.style is used with either --filter.accepted_classes or --filter.ignored_classes."
                )

        # Support manual worker_version_ids: every `MANUAL` placeholder is
        # replaced with None, in place, keeping the precedence order intact.
        if MANUAL in self.accepted_worker_version_ids:
            self.accepted_worker_version_ids[:] = [
                None if version_id == MANUAL else version_id
                for version_id in self.accepted_worker_version_ids
            ]

    @property
    def should_filter_by_class(self):
        """Whether any accepted or ignored class is configured."""
        return bool(self.accepted_classes) or bool(self.ignored_classes)

    @property
    def should_filter_by_worker(self):
        """Whether any accepted worker version id is configured."""
        return bool(self.accepted_worker_version_ids)

    @property
    def should_filter_by_style(self):
        """Whether a style filter is configured."""
        return bool(self.style)

    @property
    def should_filter_by_metadatas(self):
        """Whether any accepted metadata is configured."""
        return bool(self.accepted_metadatas)