# -*- coding: utf-8 -*-
import getpass
from dataclasses import dataclass, field
from enum import Enum
from typing import List, Optional

USER = getpass.getuser()


class Style(Enum):
    """Style class of a line element, used to filter line images.

    ``other`` corresponds to line elements that have neither the
    handwritten nor the typewritten class.
    """

    # Members are intentionally not annotated: the typing spec disallows
    # type annotations on enum members (the type is inferred from the value).
    handwritten = "handwritten"
    typewritten = "typewritten"
    other = "other"


class ExtractionMode(Enum):
    """Mode used to extract line images.

    The ``skew_*`` modes are the ones for which ``ImageArgs.skew_angle`` is used.
    """

    # Members are intentionally not annotated: the typing spec disallows
    # type annotations on enum members (the type is inferred from the value).
    # ``boundingRect`` keeps its camelCase name/value because both are part
    # of the public interface.
    boundingRect = "boundingRect"
    min_area_rect = "min_area_rect"
    deskew_min_area_rect = "deskew_min_area_rect"
    skew_min_area_rect = "skew_min_area_rect"
    polygon = "polygon"
    skew_polygon = "skew_polygon"
    deskew_polygon = "deskew_polygon"


class TranscriptionType(Enum):
    """Type of element whose transcriptions are used (page, paragraph, text_line, etc)."""

    # Members are intentionally not annotated: the typing spec disallows
    # type annotations on enum members (the type is inferred from the value).
    word = "word"
    text_line = "text_line"
    half_subject_line = "half_subject_line"
    text_zone = "text_zone"
    paragraph = "paragraph"
    act = "act"
    page = "page"
    text = "text"


@dataclass
class SelectArgs:
    """
    Arguments to select elements from Arkindex.

    Args:
        corpora (list): IDs of the corpora to use.
        volumes (list): IDs of the volumes to use.
        folders (list): IDs of the folders to use. Elements whose type is
            `volume_type` are searched recursively inside these folders.
        pages (list): IDs of the pages to use.
        selection (bool): Take the elements from the Arkindex selection.
        volume_type (str): Type name of volumes (the level just above pages),
            which may differ from one corpus to another.
    """

    # Every list option gets its own fresh empty list per instance.
    corpora: Optional[List[str]] = field(default_factory=list)
    volumes: Optional[List[str]] = field(default_factory=list)
    folders: Optional[List[str]] = field(default_factory=list)
    pages: Optional[List[str]] = field(default_factory=list)
    selection: bool = field(default=False)
    volume_type: str = field(default="volume")


@dataclass
class CommonArgs:
    """
    General arguments.

    Args:
        cache_dir (str): Directory where the full-size downloaded images are cached.
            The default path under /tmp contains the current user name (USER).
        log_parameters (bool): Save all parameters to a JSON file.
    """

    cache_dir: str = "/tmp/kaldi_data_generator_" + USER + "/cache/"
    log_parameters: bool = field(default=True)


@dataclass
class SplitArgs:
    """
    Arguments related to data splitting into training, validation and test subsets.

    Args:
        train_ratio (float): Ratio of data to be used in the training set. Should be between 0 and 1.
        test_ratio (float): Ratio of data to be used in the testing set. Should be between 0 and 1.
        val_ratio (float, optional): Ratio of data to be used in the validation set.
            The three ratios should sum to 1. When left unset (None), it is computed
            as ``1 - train_ratio - test_ratio`` from the actual train/test ratios.
        use_existing_split (bool): Use an existing split instead of a random one.
            Expects line ids to be prefixed with train, val or test.
        split_only (bool): Create the split from already downloaded lines; don't download the lines.
        no_split (bool): Don't split the data; only download the lines in the right format.
    """

    train_ratio: float = 0.8
    test_ratio: float = 0.1
    # Resolved in __post_init__. A class-level default such as
    # `1 - train_ratio - test_ratio` is evaluated only once, with the default
    # ratios, so it would ignore any train_ratio/test_ratio given by the caller.
    val_ratio: Optional[float] = None
    use_existing_split: bool = False
    split_only: bool = False
    no_split: bool = False

    def __post_init__(self):
        """Derive ``val_ratio`` from the train/test ratios when it was not given."""
        if self.val_ratio is None:
            self.val_ratio = 1 - self.train_ratio - self.test_ratio


@dataclass
class ImageArgs:
    """
    Arguments related to image transformation.

    Args:
        extraction_mode (ExtractionMode): Mode for extracting the line images. One of:
            boundingRect, min_area_rect, deskew_min_area_rect, skew_min_area_rect,
            polygon, skew_polygon, deskew_polygon.
        max_deskew_angle (int): Maximum angle by which deskewing is allowed to rotate the line image.
            If the angle determined by the deskew tool is bigger than this maximum,
            that line won't be deskewed/rotated.
        skew_angle (int): Angle by which the line image will be rotated. Useful for data augmentation:
            creating skewed text lines for a more robust model. Only used with skew_* extraction modes.
        should_rotate (bool): Use the text line rotation class to rotate lines if possible.
        grayscale (bool): Convert images to grayscale (enabled by default).
        scale_x (float, optional): Ratio of how much to scale the polygon horizontally
            (1.0 means no rescaling). Defaults to None.
        scale_y_top (float, optional): Ratio of how much to scale the polygon vertically on the top
            (1.0 means no rescaling). Defaults to None.
        scale_y_bottom (float, optional): Ratio of how much to scale the polygon vertically on the bottom
            (1.0 means no rescaling). Defaults to None.
    """

    extraction_mode: ExtractionMode = ExtractionMode.deskew_min_area_rect
    max_deskew_angle: int = 45
    skew_angle: int = 0
    should_rotate: bool = False
    grayscale: bool = True
    scale_x: Optional[float] = None
    scale_y_top: Optional[float] = None
    scale_y_bottom: Optional[float] = None


@dataclass
class FilterArgs:
    """
    Arguments related to element filtering.

    Args:
        transcription_type (TranscriptionType): Which type of elements' transcriptions to use
            (page, paragraph, text_line, etc).
        ignored_classes (list): List of ignored ml_class names. Filter lines by class.
        accepted_classes (list): List of accepted ml_class names. Filter lines by class.
        accepted_worker_version_ids (list): List of accepted worker version ids. Filter transcriptions
            by worker version ids. The order is important: at most one transcription is chosen per
            element (text_line) and the worker version order defines the precedence. If there is a
            transcription for the first worker version, it is chosen; otherwise the next worker
            version is tried.
            Use `--accepted_worker_version_ids manual` to get only manual transcriptions.
        skip_vertical_lines (bool): NOTE(review): undocumented upstream; presumably skips lines
            detected as vertical — confirm against the extraction code.
        style (Style, optional): Filter line images by style class, one of: handwritten,
            typewritten, other. 'other' corresponds to line elements that have neither the
            handwritten nor the typewritten class. Defaults to None.
        accepted_metadatas (dict): Key-value dictionary where each entry is a mandatory Arkindex
            metadata name/value. Filter lines by metadata.
        filter_parent_metadatas (bool): NOTE(review): undocumented upstream; presumably also applies
            `accepted_metadatas` to parent elements — confirm against the filtering code.
    """

    transcription_type: TranscriptionType = TranscriptionType.text_line
    ignored_classes: List[str] = field(default_factory=list)
    accepted_classes: List[str] = field(default_factory=list)
    accepted_worker_version_ids: List[str] = field(default_factory=list)
    skip_vertical_lines: bool = False
    # Explicit Optional: the default is None (PEP 484 no longer allows implicit Optional).
    style: Optional[Style] = None
    accepted_metadatas: dict = field(default_factory=dict)
    filter_parent_metadatas: bool = False