Skip to content
Snippets Groups Projects
arguments.py 1 KiB
Newer Older
Solene Tarride's avatar
Solene Tarride committed
# -*- coding: utf-8 -*-
from dataclasses import dataclass
Yoann Schneider's avatar
Yoann Schneider committed
from pathlib import Path
Solene Tarride's avatar
Solene Tarride committed


@dataclass
class BaseArgs:
    """Base dataclass that gives argument containers a ``json()`` export."""

    def json(self):
        """Return the instance attributes as a new, independent dict.

        The result is a shallow copy of the instance ``__dict__``: adding or
        removing keys on it does not modify the instance itself.
        """
        return dict(vars(self))
Solene Tarride's avatar
Solene Tarride committed


@dataclass
class CommonArgs(BaseArgs):
    """
    General arguments

    Args:
        dataset_name (str): Name of the dataset being created.
        output_dir (Path): Directory where the data should be generated.
        cache_dir (Path): Cache directory where the full-size downloaded images are saved.
        log_parameters (bool): Save all parameters to a JSON file.
    """

    dataset_name: str
    output_dir: Path
    cache_dir: Path = Path(".cache")
    log_parameters: bool = True

    def __post_init__(self):
        """Normalise the directory fields to ``Path`` and create them on disk."""
        # Dataclasses do not enforce annotations, so a caller may pass plain
        # strings here; convert them so ``.mkdir`` (and path operations done by
        # users of these fields) work regardless of the input type.
        self.output_dir = Path(self.output_dir)
        self.cache_dir = Path(self.cache_dir)
        # Directories are created at construction time; ``exist_ok=True``
        # makes this a no-op when they already exist.
        self.output_dir.mkdir(exist_ok=True, parents=True)
        self.cache_dir.mkdir(exist_ok=True, parents=True)

    def json(self):
        """Return the arguments as a dict, with directory paths as ``str``.

        Starts from the parent class export and replaces the ``Path`` values
        of ``output_dir`` and ``cache_dir`` with their string form.
        """
        data = super().json()
        data.update(
            {
                "output_dir": str(self.output_dir),
                "cache_dir": str(self.cache_dir),
            }
        )
        return data