From cca26d5e67448b3035bca70e0eff3b713e78dc2c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sol=C3=A8ne=20Tarride?= <starride@teklia.com>
Date: Fri, 22 Jul 2022 11:24:06 +0000
Subject: [PATCH] Export parameters json

---
 README.md                                     | 112 +++-
 kaldi_data_generator/arguments.py             | 146 +++++
 kaldi_data_generator/main.py                  | 536 +++++++-----------
 kaldi_data_generator/utils.py                 |  76 +++
 requirements.txt                              |   2 +-
 .../pinned_insects/partitions/TestLines.lst   |   8 +-
 .../partitions/ValidationLines.lst            |   8 +-
 tests/test_main.py                            |  16 +-
 8 files changed, 512 insertions(+), 392 deletions(-)
 create mode 100644 kaldi_data_generator/arguments.py

diff --git a/README.md b/README.md
index d70bf7a..53c20b8 100644
--- a/README.md
+++ b/README.md
@@ -1,41 +1,102 @@
-### Kaldi and kraken training data generator
+# Kaldi and kraken training data generator
 
 This script downloads pages with transcriptions from Arkindex
-and converts data to Kaldi format or kraken format. It also generates train, val and test splits.
+and converts data to Kaldi format or kraken format.
+It also generates reproducible train, val and test splits.
 
-### Using the script
+## Usage
 
-#### common for both modules
-
-`ARKINDEX_API_TOKEN` and `ARKINDEX_API_URL` environment variables must be defined.
-
-Install it as a package
+### Installation
+Install it as a package:
 ```bash
 virtualenv -p python3 .env
 source .env/bin/activate
 pip install -e .
 ```
 
-Use help to list possible parameters:
+### Environment variables
+`ARKINDEX_API_TOKEN` and `ARKINDEX_API_URL` environment variables must be defined.
+
+You can create an alias by adding this line to your `~/.bashrc`:
+```sh
+alias set_demo='export ARKINDEX_API_URL=https://demo.arkindex.org/;export ARKINDEX_API_TOKEN=my_api_token'
+```
+
+Then run:
+```sh
+source ~/.bashrc
+set_demo
+```
+
+### Arguments
+
+Use `--help` to list the available parameters (or read [`kaldi_data_generator/arguments.py`](kaldi_data_generator/arguments.py)):
 ```bash
 kaldi-data-generator --help
 ```
 
-There is also an option that skips all vertical transcriptions and it is `--skip_vertical_lines`
-#### Kaldi format
-Simple example:
+You can also set the arguments using a JSON or YAML configuration file:
+```yaml
+---
+format: kaldi
+dataset_name: balsac
+out_dir: my_balsac_kaldi
+common:
+  cache_dir: "/tmp/kaldi_data_generator_solene/cache/"
+  log_parameters: true
+image:
+  extraction_mode: deskew_min_area_rect
+  max_deskew_angle: 45
+split:
+  train_ratio: 0.8
+  test_ratio: 0.1
+select:
+  pages:
+  - 18c1d2d9-72e8-4f7a-a866-78b59dd407dd
+  - 901b9c27-1cbe-44ea-94a0-d9c783f17905
+  - db9dd27c-e96c-43c2-bf29-991212243453
+  - b87999e2-3733-43b1-b8ef-0a297f90bf0f
+  - 7fe3d786-068f-48c9-ae63-86db2f986c4c
+  - 4fc61e75-4a11-42e3-b317-348451629bda
+  - 3e7e37c2-d0cc-41b3-8d8c-6de0bbc69012
+  - 63b6e80b-a825-4068-a12a-d12e3edf5f80
+  - b11decff-1c07-4c51-a5be-401974ea55ea
+  - 735cdde6-e540-4dbd-b271-2206e2498156
+filter:
+  transcription_type: text_line
+```
+In this case, run:
+```sh
+kaldi-data-generator --config config.yaml
+```
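+
+The same configuration can also be given as a JSON file (a mechanical, abridged translation of the YAML example above; values are illustrative):
+```json
+{
+  "format": "kaldi",
+  "dataset_name": "balsac",
+  "out_dir": "my_balsac_kaldi",
+  "common": {"cache_dir": "/tmp/kaldi_data_generator_solene/cache/", "log_parameters": true},
+  "image": {"extraction_mode": "deskew_min_area_rect", "max_deskew_angle": 45},
+  "split": {"train_ratio": 0.8, "test_ratio": 0.1},
+  "select": {"pages": ["18c1d2d9-72e8-4f7a-a866-78b59dd407dd", "901b9c27-1cbe-44ea-94a0-d9c783f17905"]},
+  "filter": {"transcription_type": "text_line"}
+}
+```
+and passed the same way: `kaldi-data-generator --config config.json`.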
+
+You can also override some parameters directly on the command line:
+```sh
+kaldi-data-generator --config config.yaml -o my_balsac_kraken -f kraken --image.extraction_mode polygon
+```
+
+Every run exports a `config.yaml` file and a parameters JSON file to the output directory; these can be used to reproduce the data generation.
+
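+A minimal sketch of the exported parameter file (named `param-<dataset_name>-<format>-<date>.json`; the structure comes from `export_parameters` in `kaldi_data_generator/utils.py`, values are illustrative and abridged):
+```json
+{
+  "config": {"format": "kaldi", "dataset_name": "balsac", "out_dir": "my_balsac_kaldi", "common": {}, "image": {}, "split": {}, "select": {}, "filter": {}},
+  "dataset": {"train": ["<line ids>"], "val": ["<line ids>"], "test": ["<line ids>"]},
+  "info": {"user": "solene", "date": "2022-07-22|11-24-06", "device": "my-machine", "arkindex_api_url": "https://demo.arkindex.org/"}
+}
+```
+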
+## Examples
+
+> :pencil: The element ids below come from https://demo.arkindex.org/; run `set_demo` first.
+
+Two formats are currently supported: `kaldi` and `kraken`.
+### Kaldi format
+
+#### With page ids
 ```bash
-kaldi-data-generator -f kaldi --dataset_name my_balsac --out_dir /tmp/balsac/ --volumes 8f4005e9-1921-47b0-be7b-e27c7fd29486  d2f7c563-1622-4721-bd51-96fab97189f7
+kaldi-data-generator -f kaldi --dataset_name my_balsac --out_dir balsac --select.pages [18c1d2d9-72e8-4f7a-a866-78b59dd407dd,901b9c27-1cbe-44ea-94a0-d9c783f17905,db9dd27c-e96c-43c2-bf29-991212243453]
 ```
 
-With corpus ids
+#### With volume ids
 ```bash
-kaldi-data-generator -f kaldi --dataset_name cz --out_dir /tmp/home_cz/ --corpora 1ed45e94-9108-4029-a529-9abe37f55ba0
+kaldi-data-generator -f kaldi --dataset_name my_balsac --out_dir balsac --select.volumes [1d5a26d8-6a3e-45ed-bbb6-5a33d09782aa,46a3426f-86d4-45f1-bd57-0de43cd63efd,85207944-2230-4b76-a98f-735a11506743]
 ```
 
-Polygon example:
+#### With corpus ids
 ```bash
-kaldi-data-generator -f kaldi --dataset_name my_balsac2 --extraction_mode polygon --out_dir /tmp/balsac/ --pages 50e1c3c0-2fe9-4216-805e-1a2fd2e7e9f4
+kaldi-data-generator -f kaldi --dataset_name my_balsac --out_dir balsac --select.corpora [135eb31f-2c33-4ae3-be4e-2ae9adfd7c75] --select.volume_type page
 ```
 
 The script creates 3 directories `Lines`, `Transcriptions`, `Partitions` in the specified `out_dir`.
@@ -43,15 +104,14 @@ The contents of these directories must be copied (or symlinked) to the correspon
 
 ### Kraken format
 
-simple examples:
-```
-$ kaldi-data-generator -f kraken -o <output_dir> --volumes <volume_id> --no_split
-```
-For instance to download the 4 sets from IAM (2 validation set on Arkindex) in 3 directories :
+Create a kraken database based on existing splits:
 ```
-$ kaldi-data-generator -f kraken -o iam_training --volumes e7a95479-e5fc-4b20-830c-0c6e38bf8f72 --no_split
-$ kaldi-data-generator -f kraken -o iam_validation --volumes edc78ee1-09e0-4671-806b-5fc0392707d9 --no_split
-$ kaldi-data-generator -f kraken -o iam_validation --volumes fefbbfca-a6dd-4e00-8797-0d4628cb024d --no_split
-$ kaldi-data-generator -f kraken -o iam_test --volumes 0ce2b631-01d7-49bf-b213-ceb6eae74a9b --no_split
+kaldi-data-generator -f kraken -o balsac_training --select.volumes [e7a95479-e5fc-4b20-830c-0c6e38bf8f72] --split.no_split
+kaldi-data-generator -f kraken -o balsac_validation --select.volumes [db9dd27c-e96c-43c2-bf29-991212243453] --split.no_split
+kaldi-data-generator -f kraken -o balsac_test --select.volumes [901b9c27-1cbe-44ea-94a0-d9c783f17905] --split.no_split
 ```
 
+## TODO
+* Pylaia format
+* DAN format
+* Resize images (fixed height, fixed width, rescale, ...)
\ No newline at end of file
diff --git a/kaldi_data_generator/arguments.py b/kaldi_data_generator/arguments.py
new file mode 100644
index 0000000..9c1d8a6
--- /dev/null
+++ b/kaldi_data_generator/arguments.py
@@ -0,0 +1,146 @@
+# -*- coding: utf-8 -*-
+import getpass
+from dataclasses import dataclass, field
+from enum import Enum
+from typing import List, Optional
+
+USER = getpass.getuser()
+
+
+class Style(Enum):
+    handwritten: str = "handwritten"
+    typewritten: str = "typewritten"
+    other: str = "other"
+
+
+class ExtractionMode(Enum):
+    boundingRect: str = "boundingRect"
+    min_area_rect: str = "min_area_rect"
+    deskew_min_area_rect: str = "deskew_min_area_rect"
+    skew_min_area_rect: str = "skew_min_area_rect"
+    polygon: str = "polygon"
+    skew_polygon: str = "skew_polygon"
+    deskew_polygon: str = "deskew_polygon"
+
+
+class TranscriptionType(Enum):
+    word: str = "word"
+    text_line: str = "text_line"
+    half_subject_line: str = "half_subject_line"
+    text_zone: str = "text_zone"
+    paragraph: str = "paragraph"
+    act: str = "act"
+    page: str = "page"
+
+
+@dataclass
+class SelectArgs:
+    """
+    Arguments to select elements from Arkindex
+
+    Args:
+        corpora (list): List of corpus ids to be used.
+        volumes (list): List of volume ids to be used.
+        folders (list): List of folder ids to be used. Elements of `volume_type` will be searched recursively in these folders.
+        pages (list): List of page ids to be used.
+        selection (bool): Get elements from the Arkindex selection.
+        volume_type (str): Element type of the volumes (1 level above page); it may have a different name depending on the corpus.
+    """
+
+    corpora: Optional[List[str]] = field(default_factory=list)
+    volumes: Optional[List[str]] = field(default_factory=list)
+    folders: Optional[List[str]] = field(default_factory=list)
+    pages: Optional[List[str]] = field(default_factory=list)
+    selection: bool = False
+    volume_type: str = "volume"
+
+
+@dataclass
+class CommonArgs:
+    """
+    General arguments
+
+    Args:
+        cache_dir (str): Cache directory where the full-size downloaded images are saved.
+        log_parameters (bool): Save all parameters to a JSON file.
+    """
+
+    cache_dir: str = f"/tmp/kaldi_data_generator_{USER}/cache/"
+    log_parameters: bool = True
+
+
+@dataclass
+class SplitArgs:
+    """
+    Arguments related to data splitting into training, validation and test subsets.
+
+    Args:
+        train_ratio (float): Ratio of data to be used in the training set. Should be between 0 and 1.
+        test_ratio (float): Ratio of data to be used in the testing set. Should be between 0 and 1.
+        val_ratio (float): Ratio of data to be used in the validation set. The sum of the three ratios should equal 1.
+        use_existing_split (bool): Use an existing split instead of a random one. Expects line ids to be prefixed with the split name (train, val or test).
+        split_only (bool): Create the split from already downloaded lines; don't download the lines.
+        no_split (bool): Don't split the data; just download the lines in the right format.
+    """
+
+    train_ratio: float = 0.8
+    test_ratio: float = 0.1
+    val_ratio: float = 1 - train_ratio - test_ratio
+    use_existing_split: bool = False
+    split_only: bool = False
+    no_split: bool = False
+
+
+@dataclass
+class ImageArgs:
+    """
+    Arguments related to image transformation.
+
+    Args:
+        extraction_mode: Mode for extracting the line images (one of the `ExtractionMode` values).
+        max_deskew_angle: Maximum angle by which deskewing is allowed to rotate the line image.
+            If the angle determined by the deskew tool is bigger than this maximum, that line won't be deskewed/rotated.
+        skew_angle: Angle by which the line image will be rotated. Useful for data augmentation:
+            creating skewed text lines for a more robust model. Only used with skew_* extraction modes.
+        should_rotate (bool): Use text line rotation class to rotate lines if possible.
+        grayscale (bool): Convert images to grayscale (default) instead of color.
+        scale_x (float): Ratio of how much to scale the polygon horizontally (1.0 means no rescaling).
+        scale_y_top (float): Ratio of how much to scale the polygon vertically on the top (1.0 means no rescaling).
+        scale_y_bottom (float): Ratio of how much to scale the polygon vertically on the bottom (1.0 means no rescaling).
+    """
+
+    extraction_mode: ExtractionMode = ExtractionMode.deskew_min_area_rect
+    max_deskew_angle: int = 45
+    skew_angle: int = 0
+    should_rotate: bool = False
+    grayscale: bool = True
+    scale_x: Optional[float] = None
+    scale_y_top: Optional[float] = None
+    scale_y_bottom: Optional[float] = None
+
+
+@dataclass
+class FilterArgs:
+    """
+    Arguments related to element filtering.
+
+    Args:
+        transcription_type: Which type of elements' transcriptions to use (page, paragraph, text_line, etc.).
+        ignored_classes: List of ignored ml_class names, used to filter out lines by class.
+        accepted_classes: List of accepted ml_class names, used to filter lines by class.
+        accepted_worker_version_ids: List of accepted worker version ids. Filter transcriptions by worker version ids.
+            The order is important - only up to one transcription will be chosen per element (text_line)
+            and the worker version order defines the precedence. If there exists a transcription for
+            the first worker version then it will be chosen, otherwise the next worker version is tried.
+            Use `--filter.accepted_worker_version_ids manual` to get only manual transcriptions.
+        skip_vertical_lines: Skip vertical lines when downloading.
+        style: Filter line images by style class. 'other' corresponds to line elements that have
+            neither the handwritten nor the typewritten class.
+    """
+
+    transcription_type: TranscriptionType = TranscriptionType.text_line
+    ignored_classes: List[str] = field(default_factory=list)
+    accepted_classes: List[str] = field(default_factory=list)
+    accepted_worker_version_ids: List[str] = field(default_factory=list)
+    skip_vertical_lines: bool = False
+    style: Optional[Style] = None
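+
+
+if __name__ == "__main__":
+    # Minimal sketch (illustrative only, not used by the CLI flow): the argument
+    # groups above can also be instantiated directly, as tests/test_main.py does
+    # with FilterArgs. The field values below are assumptions for demonstration.
+    image = ImageArgs(extraction_mode=ExtractionMode.polygon, grayscale=False)
+    split = SplitArgs(train_ratio=0.7, test_ratio=0.2, val_ratio=0.1)
+    select = SelectArgs(pages=["<page-id>"])
+    print(image, split, select, sep="\n")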
diff --git a/kaldi_data_generator/main.py b/kaldi_data_generator/main.py
index 3ae9200..538e720 100644
--- a/kaldi_data_generator/main.py
+++ b/kaldi_data_generator/main.py
@@ -1,8 +1,6 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
-
 import argparse
-import getpass
 import os
 import random
 from collections import Counter, defaultdict
@@ -18,10 +16,20 @@ from arkindex import options_from_env
 from line_image_extractor.extractor import extract, read_img, save_img
 from line_image_extractor.image_utils import WHITE, Extraction, rotate_and_trim
 
+import jsonargparse
+from kaldi_data_generator.arguments import (
+    CommonArgs,
+    FilterArgs,
+    ImageArgs,
+    SelectArgs,
+    SplitArgs,
+    Style,
+)
 from kaldi_data_generator.image_utils import download_image, resize_transcription_data
 from kaldi_data_generator.utils import (
     CachedApiClient,
     TranscriptionData,
+    export_parameters,
     logger,
     write_file,
 )
@@ -31,6 +39,7 @@ random.seed(SEED)
 MANUAL = "manual"
 TEXT_LINE = "text_line"
 DEFAULT_RESCALE = 1.0
+STYLE_CLASSES = [el.value for el in Style]
 
 ROTATION_CLASSES_TO_ANGLES = {
     "rotate_0": 0,
@@ -46,67 +55,46 @@ def create_api_client(cache_dir=None):
     return CachedApiClient(cache_root=cache_dir, **options_from_env())
 
 
-class Style(Enum):
-    handwritten: str = "handwritten"
-    typewritten: str = "typewritten"
-    other: str = "other"
-
-
-STYLE_CLASSES = [s.name for s in [Style.handwritten, Style.typewritten]]
-
-
 class HTRDataGenerator:
     def __init__(
         self,
         format,
-        dataset_name="foo",
-        out_dir_base="/tmp/kaldi_data",
-        grayscale=True,
-        extraction=Extraction.boundingRect,
-        accepted_classes=None,
-        ignored_classes=None,
-        style=None,
-        skip_vertical_lines=False,
-        accepted_worker_version_ids=None,
-        transcription_type=TEXT_LINE,
-        max_deskew_angle=45,
-        skew_angle=0,
-        should_rotate=False,
-        scale_x=None,
-        scale_y_top=None,
-        scale_y_bottom=None,
-        cache_dir=None,
+        dataset_name="my_dataset",
+        out_dir_base="data",
+        image=ImageArgs(),
+        common=CommonArgs(),
+        filter=FilterArgs(),
         api_client=None,
     ):
 
         self.format = format
         self.out_dir_base = out_dir_base
         self.dataset_name = dataset_name
-        self.grayscale = grayscale
-        self.extraction_mode = extraction
-        self.accepted_classes = accepted_classes
-        self.ignored_classes = ignored_classes
+        self.grayscale = image.grayscale
+        self.extraction_mode = Extraction[image.extraction_mode.value]
+        self.accepted_classes = filter.accepted_classes
+        self.ignored_classes = filter.ignored_classes
         self.should_filter_by_class = bool(self.accepted_classes) or bool(
             self.ignored_classes
         )
-        self.accepted_worker_version_ids = accepted_worker_version_ids
+        self.accepted_worker_version_ids = filter.accepted_worker_version_ids
         self.should_filter_by_worker = bool(self.accepted_worker_version_ids)
-        self.style = style
+        self.style = filter.style
         self.should_filter_by_style = bool(self.style)
-        self.transcription_type = transcription_type
-        self.skip_vertical_lines = skip_vertical_lines
+        self.transcription_type = filter.transcription_type.value
+        self.skip_vertical_lines = filter.skip_vertical_lines
         self.skipped_pages_count = 0
         self.skipped_vertical_lines_count = 0
         self.accepted_lines_count = 0
-        self.max_deskew_angle = max_deskew_angle
-        self.skew_angle = skew_angle
-        self.should_rotate = should_rotate
-        if scale_x or scale_y_top or scale_y_bottom:
+        self.max_deskew_angle = image.max_deskew_angle
+        self.skew_angle = image.skew_angle
+        self.should_rotate = image.should_rotate
+        if image.scale_x or image.scale_y_top or image.scale_y_bottom:
             self.should_resize_polygons = True
             # use 1.0 as default - no resize, if not specified
-            self.scale_x = scale_x or DEFAULT_RESCALE
-            self.scale_y_top = scale_y_top or DEFAULT_RESCALE
-            self.scale_y_bottom = scale_y_bottom or DEFAULT_RESCALE
+            self.scale_x = image.scale_x or DEFAULT_RESCALE
+            self.scale_y_top = image.scale_y_top or DEFAULT_RESCALE
+            self.scale_y_bottom = image.scale_y_bottom or DEFAULT_RESCALE
         else:
             self.should_resize_polygons = False
         self.api_client = api_client
@@ -129,7 +117,7 @@ class HTRDataGenerator:
             )
             os.makedirs(self.out_line_img_dir, exist_ok=True)
 
-        self.cache_dir = cache_dir
+        self.cache_dir = Path(common.cache_dir)
         logger.info(f"Setting up cache to {self.cache_dir}")
         self.img_cache_dir = self.cache_dir / "images"
         self.img_cache_dir.mkdir(exist_ok=True, parents=True)
@@ -486,17 +474,24 @@ class HTRDataGenerator:
                     trans.text,
                 )
 
-    def run_selection(self):
+    def run_selection(self, select):
+        """
+        Update `select` with the ids of the elements in the current Arkindex selection.
+        """
         selected_elems = [e for e in self.api_client.paginate("ListSelection")]
         for elem_type, elems_of_type in groupby(
             selected_elems, key=lambda x: x["type"]
         ):
+            elements_ids = [el["id"] for el in elems_of_type]
             if elem_type == "page":
-                self.run_pages(list(elems_of_type))
-            elif elem_type in ["volume", "folder"]:
-                self.run_volumes(list(elems_of_type))
+                select.pages += elements_ids
+            elif elem_type == "volume":
+                select.volumes += elements_ids
+            elif elem_type == "folder":
+                select.folders += elements_ids
             else:
                 raise ValueError(f"Unsupported element type {elem_type} in selection!")
+        return select
 
     def run_pages(self, pages: list):
         if all(isinstance(n, str) for n in pages):
@@ -548,15 +543,9 @@ class HTRDataGenerator:
 
 
 class Split(Enum):
-    Train: int = 0
-    Test: int = 1
-    Validation: int = 2
-
-    @property
-    def short_name(self) -> str:
-        if self == self.Validation:
-            return "val"
-        return self.name.lower()
+    Train: str = "train"
+    Test: str = "test"
+    Validation: str = "val"
 
 
 class KaldiPartitionSplitter:
@@ -564,44 +553,56 @@ class KaldiPartitionSplitter:
         self,
         out_dir_base="/tmp/kaldi_data",
         split_train_ratio=0.8,
+        split_val_ratio=0.1,
         split_test_ratio=0.1,
         use_existing_split=False,
     ):
         self.out_dir_base = out_dir_base
         self.split_train_ratio = split_train_ratio
         self.split_test_ratio = split_test_ratio
-        self.split_val_ratio = 1 - self.split_train_ratio - self.split_test_ratio
+        self.split_val_ratio = split_val_ratio
         self.use_existing_split = use_existing_split
 
     def page_level_split(self, line_ids: list) -> dict:
-        # need to sort again, because `set` will lose the order
-        page_ids = sorted({"_".join(line_id.split("_")[:-2]) for line_id in line_ids})
+        """
+        Split pages into train, validation and test subsets.
+        Lines from the same page all end up in the same subset, to avoid data leakage.
+
+        Args:
+            line_ids (list): a list of line ids named {page_id}_{line_number}_{line_id}
+        """
+        # Get page ids from line ids to create splits at page level
+        page_ids = ["_".join(line_id.split("_")[:-2]) for line_id in line_ids]
+        # Remove duplicates and sort for reproducibility
+        page_ids = sorted(set(page_ids))
         random.Random(SEED).shuffle(page_ids)
         page_count = len(page_ids)
 
-        train_page_ids = page_ids[: round(page_count * self.split_train_ratio)]
-        page_ids = page_ids[round(page_count * self.split_train_ratio) :]
-
-        test_page_ids = page_ids[: round(page_count * self.split_test_ratio)]
-        page_ids = page_ids[round(page_count * self.split_test_ratio) :]
-
-        val_page_ids = page_ids
+        # Use np.split to split in three sets
+        stop_train_idx = round(page_count * self.split_train_ratio)
+        stop_val_idx = stop_train_idx + round(page_count * self.split_val_ratio)
+        train_page_ids, val_page_ids, test_page_ids = np.split(
+            page_ids, [stop_train_idx, stop_val_idx]
+        )
 
+        # Build dictionary that will be used to split lines {id: split}
         page_dict = {page_id: Split.Train.value for page_id in train_page_ids}
-        page_dict.update({page_id: Split.Test.value for page_id in test_page_ids})
         page_dict.update({page_id: Split.Validation.value for page_id in val_page_ids})
+        page_dict.update({page_id: Split.Test.value for page_id in test_page_ids})
         return page_dict
 
-    def existing_split(self, line_ids: list) -> list:
+    def existing_split(self, line_ids: list) -> dict:
-        split_dict = {split.short_name: [] for split in Split}
+        """
+        Expect line_ids to be named {split}/{path_to_image} where split in ["train", "val", "test"]
+        """
+        split_dict = {split.value: [] for split in Split}
         for line_id in line_ids:
             split_prefix = line_id.split("/")[0].lower()
             split_dict[split_prefix].append(line_id)
-        splits = [split_dict[split.short_name] for split in Split]
-        return splits
+        return split_dict
 
     def create_partitions(self):
-        logger.info("Creating partitions")
+        """ """
+        logger.info(f"Creating {[split.value for split in Split]} partitions")
+        # Get all images ids (and remove extension)
         lines_path = Path(f"{self.out_dir_base}/Lines")
         line_ids = [
             str(file.relative_to(lines_path).with_suffix(""))
@@ -613,7 +614,8 @@ class KaldiPartitionSplitter:
             datasets = self.existing_split(line_ids)
         else:
             page_dict = self.page_level_split(line_ids)
-            datasets = [[] for _ in range(3)]
+            # Extend the page-level split to individual lines
+            datasets = {s.value: [] for s in Split}
             for line_id in line_ids:
                 page_id = "_".join(line_id.split("_")[:-2])
                 split_id = page_dict[page_id]
@@ -621,315 +623,157 @@ class KaldiPartitionSplitter:
 
         partitions_dir = os.path.join(self.out_dir_base, "Partitions")
         os.makedirs(partitions_dir, exist_ok=True)
-        for i, dataset in enumerate(datasets):
-            if not dataset:
-                logger.info(f"Partition {Split(i).name} is empty! Skipping..")
+        for split, split_line_ids in datasets.items():
+            if not split_line_ids:
+                logger.info(f"Partition {split} is empty! Skipping...")
                 continue
-            file_name = f"{partitions_dir}/{Split(i).name}Lines.lst"
-            write_file(file_name, "\n".join(dataset) + "\n")
+            file_name = f"{partitions_dir}/{Split(split).name}Lines.lst"
+            write_file(file_name, "\n".join(split_line_ids) + "\n")
+        return datasets
 
 
-def create_parser():
-    user_name = getpass.getuser()
+def run(format, dataset_name, out_dir, common, image, split, select, filter):
 
-    parser = argparse.ArgumentParser(
-        description="Script to generate Kaldi or kraken training data from annotations from Arkindex",
-        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
-    )
-    parser.add_argument(
-        "-f",
-        "--format",
-        type=str,
-        help="is the data generated going to be used for kaldi or kraken",
-    )
-    parser.add_argument(
-        "-n",
-        "--dataset_name",
-        type=str,
-        help="Name of the dataset being created for kaldi or kraken "
-        "(useful for distinguishing different datasets when in Lines or Transcriptions directory)",
-    )
-    parser.add_argument(
-        "-o", "--out_dir", type=str, required=True, help="output directory"
-    )
-    parser.add_argument(
-        "--train_ratio",
-        type=float,
-        default=0.8,
-        help="Ratio of pages to be used in train (between 0 and 1)",
-    )
-    parser.add_argument(
-        "--test_ratio",
-        type=float,
-        default=0.1,
-        help="Ratio of pages to be used in test (between 0 and 1 - train_ratio)",
-    )
-    parser.add_argument(
-        "--use_existing_split",
-        action="store_true",
-        default=False,
-        help="Use an existing split instead of random. "
-        "Expecting line_ids to be prefixed with (train, val and test)",
-    )
-    parser.add_argument(
-        "--split_only",
-        "--no_download",
-        action="store_true",
-        default=False,
-        help="Create the split from already downloaded lines, don't download the lines",
-    )
-    parser.add_argument(
-        "--no_split",
-        action="store_true",
-        default=False,
-        help="No splitting of the data to be done just download the line in the right format",
-    )
+    api_client = create_api_client(Path(common.cache_dir))
 
-    parser.add_argument(
-        "-e",
-        "--extraction_mode",
-        type=lambda x: Extraction[x],
-        default=Extraction.boundingRect,
-        help=f"Mode for extracting the line images: {[e.name for e in Extraction]}",
-    )
-
-    parser.add_argument(
-        "--max_deskew_angle",
-        type=int,
-        default=45,
-        help="Maximum angle by which deskewing is allowed to rotate the line image. "
-        "If the angle determined by deskew tool is bigger than max "
-        "then that line won't be deskewed/rotated.",
-    )
-
-    parser.add_argument(
-        "--skew_angle",
-        type=int,
-        default=0,
-        help="Angle by which the line image will be rotated. Useful for data augmnetation"
-        " - creating skewed text lines for a more robust model."
-        " Only used with skew_* extraction modes.",
-    )
+    if not split.split_only:
+        data_generator = HTRDataGenerator(
+            format=format,
+            dataset_name=dataset_name,
+            out_dir_base=out_dir,
+            common=common,
+            image=image,
+            filter=filter,
+            api_client=api_client,
+        )
 
-    parser.add_argument(
-        "--should_rotate",
-        action="store_true",
-        default=False,
-        help="Use text line rotation class to rotate lines if possible",
-    )
+        # extract all the lines and transcriptions
+        if select.selection:
+            select = data_generator.run_selection(select)
+        if select.pages:
+            data_generator.run_pages(select.pages)
+        if select.volumes:
+            data_generator.run_volumes(select.volumes)
+        if select.folders:
+            data_generator.run_folders(select.folders, select.volume_type)
+        if select.corpora:
+            data_generator.run_corpora(select.corpora, select.volume_type)
+        if data_generator.skipped_vertical_lines_count > 0:
+            logger.info(
+                f"Number of skipped pages: {data_generator.skipped_pages_count}"
+            )
+            _skipped_vertical_count = data_generator.skipped_vertical_lines_count
+            _total_count = _skipped_vertical_count + data_generator.accepted_lines_count
+            skipped_ratio = _skipped_vertical_count / _total_count * 100
 
-    parser.add_argument(
-        "--transcription_type",
-        type=str,
-        default="text_line",
-        help="Which type of elements' transcriptions to use? (page, paragraph, text_line, etc)",
-    )
+            logger.info(
+                f"Skipped {data_generator.skipped_vertical_lines_count} vertical lines ({round(skipped_ratio, 2)}%)"
+            )
+    else:
+        logger.info("Creating a split from already downloaded files")
+        data_generator = None
 
-    group = parser.add_mutually_exclusive_group(required=False)
-    group.add_argument(
-        "--grayscale",
-        action="store_true",
-        dest="grayscale",
-        help="Convert images to grayscale (By default grayscale)",
-    )
-    group.add_argument(
-        "--color", action="store_false", dest="grayscale", help="Use color images"
-    )
-    group.set_defaults(grayscale=True)
+    if not split.no_split:
+        kaldi_partitioner = KaldiPartitionSplitter(
+            out_dir_base=out_dir,
+            split_train_ratio=split.train_ratio,
+            split_val_ratio=split.val_ratio,
+            split_test_ratio=split.test_ratio,
+            use_existing_split=split.use_existing_split,
+        )
 
-    parser.add_argument(
-        "--corpora",
-        nargs="*",
-        help="List of corpus ids to be used, separated by spaces",
-    )
-    parser.add_argument(
-        "--folders",
-        type=str,
-        nargs="*",
-        help="List of folder ids to be used, separated by spaces. "
-        "Elements of `volume_type` will be searched recursively in these folders",
-    )
-    parser.add_argument(
-        "--volumes",
-        nargs="*",
-        help="List of volume ids to be used, separated by spaces",
-    )
-    parser.add_argument(
-        "--pages", nargs="*", help="List of page ids to be used, separated by spaces"
-    )
-    parser.add_argument(
-        "-v",
-        "--volume_type",
-        type=str,
-        default="volume",
-        help="Volumes (1 level above page) may have a different name on corpora",
-    )
-    parser.add_argument(
-        "--selection",
-        action="store_true",
-        default=False,
-        help="Get elements from selection",
-    )
+        # create partitions from all the extracted data
+        datasets = kaldi_partitioner.create_partitions()
+    else:
+        logger.info("No split to be done")
+        datasets = {}
 
-    parser.add_argument(
-        "--skip_vertical_lines",
-        action="store_true",
-        default=False,
-        help="skips vertical lines when downloading",
-    )
+    logger.info("DONE")
 
-    parser.add_argument(
-        "--ignored_classes",
-        nargs="*",
-        default=[],
-        help="List of ignored ml_class names. Filter lines by class",
+    export_parameters(
+        format,
+        dataset_name,
+        out_dir,
+        common,
+        image,
+        split,
+        select,
+        filter,
+        datasets,
+        arkindex_api_url=options_from_env()["base_url"],
     )
 
-    parser.add_argument(
-        "--accepted_classes",
-        nargs="*",
-        default=[],
-        help="List of accepted ml_class names. Filter lines by class",
+    logger.warning(
+        f"Consider cleaning your cache directory {common.cache_dir} if you are done."
     )
 
-    parser.add_argument(
-        "--accepted_worker_version_ids",
-        nargs="*",
-        default=[],
-        help="List of accepted worker version ids. Filter transcriptions by worker version ids."
-        "The order is important - only up to one transcription will be chosen per element (text_line)"
-        " and the worker version order defines the precedence. If there exists a transcription for"
-        " the first worker version then it will be chosen, otherwise will continue on to the next"
-        " worker version."
-        " Use `--accepted_worker_version_ids manual` to get only manual transcriptions",
-    )
 
-    parser.add_argument(
-        "--style",
-        type=lambda x: Style[x.lower()],
-        default=None,
-        help=f"Filter line images by style class. 'other' corresponds to line elements that "
-        f"have neither handwritten or typewritten class : {[s.name for s in Style]}",
+def get_args():
+    parser = jsonargparse.ArgumentParser(
+        description="Script to generate Kaldi or kraken training data from annotations from Arkindex",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+        parse_as_dict=True,
     )
-
     parser.add_argument(
-        "--scale_x",
-        type=float,
-        default=None,
-        help="Ratio of how much to scale the polygon horizontally (1.0 means no rescaling)",
+        "--config", action=jsonargparse.ActionConfigFile, help="Configuration file"
     )
     parser.add_argument(
-        "--scale_y_top",
-        type=float,
-        default=None,
-        help="Ratio of how much to scale the polygon vertically on the top (1.0 means no rescaling)",
+        "-f",
+        "--format",
+        type=str,
+        required=True,
+        help="is the data generated going to be used for kaldi or kraken",
     )
-
     parser.add_argument(
-        "--scale_y_bottom",
-        type=float,
-        default=None,
-        help="Ratio of how much to scale the polygon vertically on the bottom (1.0 means no rescaling)",
+        "-d",
+        "--dataset_name",
+        type=str,
+        required=True,
+        help="Name of the dataset being created for kaldi or kraken "
+        "(useful for distinguishing different datasets when in Lines or Transcriptions directory)",
     )
-
     parser.add_argument(
-        "--cache_dir",
-        type=Path,
-        default=Path(f"/tmp/kaldi_data_generator_{user_name}/cache/"),
-        help="Cache dir where to save the full size downloaded images. Change it to force redownload.",
+        "-o", "--out_dir", type=str, required=True, help="output directory"
     )
 
-    return parser
-
-
-def main():
-    parser = create_parser()
-    args = parser.parse_args()
-
-    if not args.dataset_name and not args.split_only and not args.format == "kraken":
-        parser.error("--dataset_name must be specified (unless --split-only)")
-
-    if args.accepted_classes and args.ignored_classes:
-        if set(args.accepted_classes) & set(args.ignored_classes):
+    parser.add_class_arguments(CommonArgs, "common")
+    parser.add_class_arguments(ImageArgs, "image")
+    parser.add_class_arguments(SplitArgs, "split")
+    parser.add_class_arguments(FilterArgs, "filter")
+    parser.add_class_arguments(SelectArgs, "select")
+
+    args = parser.parse_args(with_meta=False)
+    args["common"] = CommonArgs(**args["common"])
+    args["image"] = ImageArgs(**args["image"])
+    args["split"] = SplitArgs(**args["split"])
+    args["select"] = SelectArgs(**args["select"])
+    args["filter"] = FilterArgs(**args["filter"])
+
+    # Check overlap of accepted and ignored classes
+    accepted_classes = args["filter"].accepted_classes
+    ignored_classes = args["filter"].accepted_classes
+    if accepted_classes and ignored_classes:
+        if set(accepted_classes) & set(ignored_classes):
             parser.error(
-                f"--accepted_classes and --ignored_classes values must not overlap ({args.accepted_classes} - {args.ignored_classes})"
+                f"--filter.accepted_classes and --filter.ignored_classes values must not overlap ({accepted_classes} - {ignored_classes})"
             )
 
-    if args.style and (args.accepted_classes or args.ignored_classes):
-        if set(STYLE_CLASSES) & (
-            set(args.accepted_classes) | set(args.ignored_classes)
-        ):
+    if args["filter"].style and (accepted_classes or ignored_classes):
+        if set(STYLE_CLASSES) & (set(accepted_classes) | set(ignored_classes)):
             parser.error(
                 f"--style class values ({STYLE_CLASSES}) shouldn't be in the accepted_classes list "
                 f"(or ignored_classes list) "
-                "if both --style and --accepted_classes (or --ignored_classes) are used together."
-            )
-
-    logger.info(f"ARGS {args} \n")
-
-    api_client = create_api_client(args.cache_dir)
-
-    if not args.split_only:
-        data_generator = HTRDataGenerator(
-            format=args.format,
-            dataset_name=args.dataset_name,
-            out_dir_base=args.out_dir,
-            grayscale=args.grayscale,
-            extraction=args.extraction_mode,
-            accepted_classes=args.accepted_classes,
-            ignored_classes=args.ignored_classes,
-            style=args.style,
-            skip_vertical_lines=args.skip_vertical_lines,
-            transcription_type=args.transcription_type,
-            accepted_worker_version_ids=args.accepted_worker_version_ids,
-            max_deskew_angle=args.max_deskew_angle,
-            skew_angle=args.skew_angle,
-            should_rotate=args.should_rotate,
-            scale_x=args.scale_x,
-            scale_y_top=args.scale_y_top,
-            scale_y_bottom=args.scale_y_bottom,
-            cache_dir=args.cache_dir,
-            api_client=api_client,
-        )
-
-        # extract all the lines and transcriptions
-        if args.selection:
-            data_generator.run_selection()
-        if args.pages:
-            data_generator.run_pages(args.pages)
-        if args.volumes:
-            data_generator.run_volumes(args.volumes)
-        if args.folders:
-            data_generator.run_folders(args.folders, args.volume_type)
-        if args.corpora:
-            data_generator.run_corpora(args.corpora, args.volume_type)
-        if data_generator.skipped_vertical_lines_count > 0:
-            logger.info(
-                f"Number of skipped pages: {data_generator.skipped_pages_count}"
+                "if --filter.style is used with either --filter.accepted_classes or --filter.ignored_classes."
             )
-            _skipped_vertical_count = data_generator.skipped_vertical_lines_count
-            _total_count = _skipped_vertical_count + data_generator.accepted_lines_count
-            skipped_ratio = _skipped_vertical_count / _total_count * 100
 
-            logger.info(
-                f"Skipped {data_generator.skipped_vertical_lines_count} vertical lines ({round(skipped_ratio, 2)}%)"
-            )
-    else:
-        logger.info("Creating a split from already downloaded files")
-    if not args.no_split:
-        kaldi_partitioner = KaldiPartitionSplitter(
-            out_dir_base=args.out_dir,
-            split_train_ratio=args.train_ratio,
-            split_test_ratio=args.test_ratio,
-            use_existing_split=args.use_existing_split,
-        )
+    del args["config"]
+    return args
 
-        # create partitions from all the extracted data
-        kaldi_partitioner.create_partitions()
-    else:
-        logger.info("No split to be done")
 
-    logger.info("DONE")
+def main():
+    args = get_args()
+    logger.info(f"Arguments: {args} \n")
+    run(**args)
 
 
 if __name__ == "__main__":
diff --git a/kaldi_data_generator/utils.py b/kaldi_data_generator/utils.py
index 0aed87a..59e0bbe 100644
--- a/kaldi_data_generator/utils.py
+++ b/kaldi_data_generator/utils.py
@@ -1,8 +1,11 @@
 # -*- coding: utf-8 -*-
+import getpass
 import json
 import logging
 import os
+import socket
 import sys
+from datetime import datetime
 from pathlib import Path
 
 import cv2
@@ -10,6 +13,8 @@ import numpy as np
 from arkindex import ArkindexClient
 from line_image_extractor.image_utils import BoundingBox
 
+import yaml
+
 logging.basicConfig(
     level=logging.INFO, format="%(asctime)s %(levelname)s/%(name)s: %(message)s"
 )
@@ -132,3 +137,74 @@ class CachedApiClient(ArkindexClient):
         request_cache.parent.mkdir(parents=True, exist_ok=True)
         request_cache.write_text(json.dumps(results))
         logger.info("Saved")
+
+
+def write_json(d, filename):
+    with open(filename, "w") as f:
+        f.write(json.dumps(d, indent=4))
+
+
+def write_yaml(d, filename):
+    with open(filename, "w") as f:
+        yaml.dump(d, f)
+
+
+def export_parameters(
+    format,
+    dataset_name,
+    out_dir,
+    common,
+    image,
+    split,
+    select,
+    filter,
+    datasets,
+    arkindex_api_url,
+):
+    """
+    Dump a JSON log file to keep track of parameters for this dataset
+    """
+    # Get config dict
+    config = {
+        "format": format,
+        "dataset_name": dataset_name,
+        "out_dir": out_dir,
+        "common": vars(common),
+        "image": vars(image),
+        "split": vars(split),
+        "select": vars(select),
+        "filter": vars(filter),
+    }
+    # Serialize special classes in place; vars() returns live dicts, so `config` picks up these changes
+    image.extraction_mode = image.extraction_mode.name
+    filter.transcription_type = filter.transcription_type.value
+    common.cache_dir = str(common.cache_dir)
+
+    if common.log_parameters:
+        # Get additional info on dataset and user
+        current_datetime = datetime.now().strftime("%Y-%m-%d|%H-%M-%S")
+
+        parameters = {}
+        parameters["config"] = config
+        parameters["dataset"] = (
+            datasets
+            if isinstance(datasets, dict)
+            else {"train": datasets[0], "valid": datasets[1], "test": datasets[2]}
+        )
+        parameters["info"] = {
+            "user": getpass.getuser(),
+            "date": current_datetime,
+            "device": socket.gethostname(),
+            "arkindex_api_url": arkindex_api_url,
+        }
+
+        # Export
+        parameter_file = os.path.join(
+            out_dir, f"param-{dataset_name}-{format}-{current_datetime}.json"
+        )
+        write_json(parameters, parameter_file)
+        logger.info(f"Parameters exported in file {parameter_file}")
+
+    config_file = os.path.join(out_dir, "config.yaml")
+    write_yaml(config, config_file)
+    logger.info(f"Config exported in file {config_file}")
diff --git a/requirements.txt b/requirements.txt
index f46a80b..cca4484 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,6 @@
 apistar==0.7.2
 arkindex-client==1.0.9
+jsonargparse
 teklia-line-image-extractor==0.2.4
 tqdm==4.64.0
-
 typesystem==0.2.5
diff --git a/tests/data/pinned_insects/partitions/TestLines.lst b/tests/data/pinned_insects/partitions/TestLines.lst
index 1ce3c45..4de4011 100644
--- a/tests/data/pinned_insects/partitions/TestLines.lst
+++ b/tests/data/pinned_insects/partitions/TestLines.lst
@@ -1,6 +1,2 @@
-testing/5bcb0a49-4810-4107-9b7f-62c30f913399_005_74e593f4-e5d1-4891-ba99-e9eabc1c9a9b
-testing/5bcb0a49-4810-4107-9b7f-62c30f913399_006_51ae11a8-3fb5-47f7-80d6-48b2b235076b
-testing/5bcb0a49-4810-4107-9b7f-62c30f913399_007_8d35675d-9f43-4b69-9ad7-e319dfb3be6e
-testing/5bcb0a49-4810-4107-9b7f-62c30f913399_008_b4b16c14-ce22-4f13-88f8-7c8e4e32071f
-testing/5bcb0a49-4810-4107-9b7f-62c30f913399_009_4ab037a2-451b-40b9-b598-1d14984dd8fd
-testing/5bcb0a49-4810-4107-9b7f-62c30f913399_010_52260fd8-cd63-463e-9f74-3ed3a97d14c5
\ No newline at end of file
+testing/6139d79e-be7b-46fd-9cde-20268bfb5114_005_3750213d-2db3-4425-9b6c-0e5c0f45596f
+testing/6139d79e-be7b-46fd-9cde-20268bfb5114_006_974c3bd8-d42a-4316-a021-345d6b11379a
\ No newline at end of file
diff --git a/tests/data/pinned_insects/partitions/ValidationLines.lst b/tests/data/pinned_insects/partitions/ValidationLines.lst
index 4de4011..1ce3c45 100644
--- a/tests/data/pinned_insects/partitions/ValidationLines.lst
+++ b/tests/data/pinned_insects/partitions/ValidationLines.lst
@@ -1,2 +1,6 @@
-testing/6139d79e-be7b-46fd-9cde-20268bfb5114_005_3750213d-2db3-4425-9b6c-0e5c0f45596f
-testing/6139d79e-be7b-46fd-9cde-20268bfb5114_006_974c3bd8-d42a-4316-a021-345d6b11379a
\ No newline at end of file
+testing/5bcb0a49-4810-4107-9b7f-62c30f913399_005_74e593f4-e5d1-4891-ba99-e9eabc1c9a9b
+testing/5bcb0a49-4810-4107-9b7f-62c30f913399_006_51ae11a8-3fb5-47f7-80d6-48b2b235076b
+testing/5bcb0a49-4810-4107-9b7f-62c30f913399_007_8d35675d-9f43-4b69-9ad7-e319dfb3be6e
+testing/5bcb0a49-4810-4107-9b7f-62c30f913399_008_b4b16c14-ce22-4f13-88f8-7c8e4e32071f
+testing/5bcb0a49-4810-4107-9b7f-62c30f913399_009_4ab037a2-451b-40b9-b598-1d14984dd8fd
+testing/5bcb0a49-4810-4107-9b7f-62c30f913399_010_52260fd8-cd63-463e-9f74-3ed3a97d14c5
\ No newline at end of file
diff --git a/tests/test_main.py b/tests/test_main.py
index 9273975..853239a 100644
--- a/tests/test_main.py
+++ b/tests/test_main.py
@@ -4,15 +4,12 @@ from pathlib import Path
 
 import pytest
 
+from kaldi_data_generator.arguments import FilterArgs
 from kaldi_data_generator.main import MANUAL, HTRDataGenerator, KaldiPartitionSplitter
 
 
 def test_init():
-
-    htr_data_gen = HTRDataGenerator(
-        format="kaldi", accepted_worker_version_ids=[], cache_dir=Path("/tmp/foo/bar")
-    )
-
+    htr_data_gen = HTRDataGenerator(format="kaldi")
     assert htr_data_gen is not None
 
 
@@ -38,16 +35,14 @@ def test_run_volumes_with_worker_version(
     out_dir_base = tmpdir
     htr_data_gen = HTRDataGenerator(
         format="kaldi",
-        accepted_worker_version_ids=worker_version_ids,
+        filter=FilterArgs(accepted_worker_version_ids=worker_version_ids),
         out_dir_base=out_dir_base,
-        cache_dir=Path("/tmp/foo/bar"),
     )
     htr_data_gen.api_client = api_client
 
     htr_data_gen.get_image = mocker.MagicMock()
     # return same fake image for all the pages
     htr_data_gen.get_image.return_value = fake_image
-
     htr_data_gen.run_volumes([fake_volume_id])
 
     trans_files = list(Path(htr_data_gen.out_line_text_dir).glob("*.txt"))
@@ -65,13 +60,12 @@ def test_run_volumes_with_worker_version(
 
 def test_create_partitions(fake_expected_partitions, tmpdir):
     out_dir_base = Path(tmpdir)
-    splitter = KaldiPartitionSplitter(
-        out_dir_base=out_dir_base,
-    )
+    splitter = KaldiPartitionSplitter(out_dir_base=out_dir_base)
 
     all_ids = []
     for split_name, expected_split_ids in fake_expected_partitions.items():
         all_ids += expected_split_ids
+
     # shuffle to show that splits are reproducible
     random.shuffle(all_ids)
 
-- 
GitLab