# -*- coding: utf-8 -*-
"""
Data extraction
"""

import uuid
from pathlib import Path
from typing import Optional

from teklia_toolbox.config import ConfigParser

from atr_data_generator.arguments import CommonArgs
from atr_data_generator.extract.arguments import (
    DEFAULT_RESCALE,
    ExtractionMode,
    FilterArgs,
    ImageArgs,
    SelectArgs,
)
from atr_data_generator.extract.base import DataGenerator
from atr_data_generator.extract.pylaia.arguments import PylaiaArgs
from atr_data_generator.extract.pylaia.main import PylaiaDataGenerator
from atr_data_generator.extract.utils import ListedEnum
from atr_data_generator.split.arguments import SplitArgs


class Generators(ListedEnum):
    """
    Format-specific data generators, selectable by name with ``--format``.

    Each member's value is the generator class to instantiate.
    """

    pylaia = PylaiaDataGenerator


def _float(value):
    """Convert ``value`` to ``float``; ``None`` is passed through unchanged."""
    return None if value is None else float(value)


def _add_common_section(parser):
    """Declare the ``common`` section (it has no section-level default)."""
    common = parser.add_subparser("common")
    common.add_option("dataset_name", type=str)
    common.add_option("output_dir", type=Path)
    common.add_option("cache_dir", type=Path, default=Path(".cache"))
    common.add_option("log_parameters", type=bool, default=True)


def _add_image_section(parser):
    """Declare the ``image`` section and its nested ``image.scale`` section."""
    image = parser.add_subparser("image", default={})
    image_options = (
        ("extraction_mode", ExtractionMode, ExtractionMode.deskew_min_area_rect),
        ("fixed_height", int, None),
        ("max_deskew_angle", int, 45),
        ("skew_angle", int, 0),
        ("should_rotate", bool, False),
        ("grayscale", bool, True),
    )
    for option_name, option_type, option_default in image_options:
        image.add_option(option_name, type=option_type, default=option_default)

    # Every scale axis shares the same converter and default.
    scale = image.add_subparser("scale", default={})
    for axis in ("x", "y_top", "y_bottom"):
        scale.add_option(axis, type=_float, default=DEFAULT_RESCALE)


def _add_filter_section(parser):
    """Declare the optional ``filter`` section."""
    filters = parser.add_subparser("filter", default={})
    # ``many=True`` presumably makes this a list of values;
    # confirm against teklia_toolbox's ConfigParser.
    filters.add_option("accepted_worker_version_ids", type=str, many=True, default=[])
    filters.add_option("skip_vertical_lines", type=bool, default=False)


def _add_select_section(parser):
    """Declare the optional ``select`` section (both options default to ``None``)."""
    select = parser.add_subparser("select", default={})
    for option_name in ("parent_type", "element_type"):
        select.add_option(option_name, type=str, default=None)


def _add_split_section(parser):
    """Declare the optional ``split`` section: one UUID option per split."""
    split = parser.add_subparser("split", default={})
    for option_name in ("train_folder", "validation_folder", "test_folder"):
        split.add_option(option_name, type=uuid.UUID, default=None)


def _add_pylaia_section(parser):
    """Declare the optional PyLaia-specific ``pylaia`` section."""
    pylaia = parser.add_subparser("pylaia", default={})
    pylaia.add_option("syms_path", type=Path, default=None)


def get_parser():
    """
    Build the ``ConfigParser`` describing the extraction configuration file.

    Sections are registered in this order:
    ``common``, ``image`` (with ``image.scale``), ``filter``, ``select``,
    ``split`` and ``pylaia``.
    Every section except ``common`` declares an empty-dict default.
    """
    parser = ConfigParser()
    section_builders = (
        _add_common_section,
        _add_image_section,
        _add_filter_section,
        _add_select_section,
        _add_split_section,
        _add_pylaia_section,
    )
    for add_section in section_builders:
        add_section(parser)
    return parser


def config_parser(configuration_path: Path):
    """
    Parse ``configuration_path`` and wrap each section in its typed argument object.

    Returns a dict with these keys, in this order:
    ``"common"`` -> CommonArgs, ``"image"`` -> ImageArgs,
    ``"filter"`` -> FilterArgs, ``"select"`` -> SelectArgs,
    ``"split"`` -> SplitArgs, ``"pylaia"`` -> PylaiaArgs.

    The ``pylaia`` entry is always built: its section defaults to ``{}`` in
    ``get_parser``.
    """
    section_types = (
        ("common", CommonArgs),
        ("image", ImageArgs),
        ("filter", FilterArgs),
        ("select", SelectArgs),
        ("split", SplitArgs),
        # Format specific
        ("pylaia", PylaiaArgs),
    )
    parsed_sections = get_parser().parse(configuration_path)
    return {
        section_name: args_type(**parsed_sections[section_name])
        for section_name, args_type in section_types
    }


def add_extract_subparser(subcommands):
    """
    Register the ``extract`` sub-command on ``subcommands``.

    The module docstring is used as both its description and its help text.
    ``main`` is attached as the handler and ``config_parser`` as the
    configuration loader, both through ``set_defaults``.
    """
    extract = subcommands.add_parser(
        "extract",
        description=__doc__,
        help=__doc__,
    )
    extract.add_argument("--config", type=Path, help="Configuration file")
    extract.add_argument("--database-path", type=Path, help="Export path")
    extract.add_argument(
        "--format",
        type=str,
        choices=Generators.list(),
        help="Format of the dataset.",
    )
    extract.set_defaults(func=main, config_parser=config_parser)


def main(
    database_path: Path,
    format: Optional[str],
    **kwargs,
):
    """
    Run the data generator selected by ``format``.

    With no ``format``, the generic ``DataGenerator`` is used.
    Otherwise the class registered under that name in ``Generators`` is used.
    The generator is built from the remaining keyword arguments, then run on
    ``database_path``. Nothing is returned.
    """
    generator_class = DataGenerator
    if format is not None:
        generator_class = Generators[format].value
    generator_class(**kwargs).run(db_path=database_path)