Skip to content
Snippets Groups Projects
arguments.py 1.01 KiB
Newer Older
Solene Tarride's avatar
Solene Tarride committed
# -*- coding: utf-8 -*-
Yoann Schneider's avatar
Yoann Schneider committed
import json
from dataclasses import asdict, dataclass
from pathlib import Path
Solene Tarride's avatar
Solene Tarride committed


@dataclass
class BaseArgs:
    """Base class for argument dataclasses.

    Meant to be subclassed by ``@dataclass`` classes: the generated
    ``__init__`` calls ``__post_init__``, which runs ``_validate`` once every
    field has been assigned.
    """

    def __post_init__(self):
        # Invoked by the dataclass-generated __init__ after all fields are set.
        self._validate()

    def _validate(self):
        """Override this method to add argument validation."""
        pass

    def dict(self):
        """Return the arguments as a JSON-compatible dictionary.

        Fields are collected with :func:`dataclasses.asdict` and round-tripped
        through JSON. Values that JSON cannot encode natively (e.g.
        :class:`pathlib.Path`) are converted with ``str``. Tuples come back
        as lists.
        """
        return json.loads(json.dumps(asdict(self), default=str))
Solene Tarride's avatar
Solene Tarride committed


@dataclass
class CommonArgs(BaseArgs):
    """
    General arguments

    Args:
        dataset_name (str): Name of the dataset being created.
        output_dir (Path): Where the data should be generated. Created on
            construction if it does not exist.
        cache_dir (Path): Cache directory where to save the full size downloaded images.
            Created on construction if it does not exist.
        log_parameters (bool): Save all parameters to a JSON file.
    """

    dataset_name: str
    output_dir: Path
    cache_dir: Path = Path(".cache")
    log_parameters: bool = True

    def __post_init__(self):
        # Accept plain strings as well as Path objects (e.g. values coming
        # from a CLI or config file); ``str`` has no ``mkdir`` method, so
        # normalise before validation and directory creation.
        self.output_dir = Path(self.output_dir)
        self.cache_dir = Path(self.cache_dir)
        super().__post_init__()
        # Side effect: both directories (and any missing parents) are created
        # as soon as the arguments object is built.
        self.output_dir.mkdir(exist_ok=True, parents=True)
        self.cache_dir.mkdir(exist_ok=True, parents=True)