From cca26d5e67448b3035bca70e0eff3b713e78dc2c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sol=C3=A8ne=20Tarride?= <starride@teklia.com> Date: Fri, 22 Jul 2022 11:24:06 +0000 Subject: [PATCH] Export parameters json --- README.md | 112 +++- kaldi_data_generator/arguments.py | 146 +++++ kaldi_data_generator/main.py | 536 +++++++----------- kaldi_data_generator/utils.py | 76 +++ requirements.txt | 2 +- .../pinned_insects/partitions/TestLines.lst | 8 +- .../partitions/ValidationLines.lst | 8 +- tests/test_main.py | 16 +- 8 files changed, 512 insertions(+), 392 deletions(-) create mode 100644 kaldi_data_generator/arguments.py diff --git a/README.md b/README.md index d70bf7a..53c20b8 100644 --- a/README.md +++ b/README.md @@ -1,41 +1,102 @@ -### Kaldi and kraken training data generator +# Kaldi and kraken training data generator This script downloads pages with transcriptions from Arkindex -and converts data to Kaldi format or kraken format. It also generates train, val and test splits. +and converts data to Kaldi format or kraken format. +It also generates reproducible train, val and test splits. -### Using the script +## Usage -#### common for both modules - -`ARKINDEX_API_TOKEN` and `ARKINDEX_API_URL` environment variables must be defined. - -Install it as a package +### Installation +Install it as a package: ```bash virtualenv -p python3 .env source .env/bin/activate pip install -e . ``` -Use help to list possible parameters: +### Environment variables +`ARKINDEX_API_TOKEN` and `ARKINDEX_API_URL` environment variables must be defined. + +You can create an alias by adding this line to your `~/.bashrc`: +```sh +alias set_demo='export ARKINDEX_API_URL=https://demo.arkindex.org/;export ARKINDEX_API_TOKEN=my_api_token' +``` + +Then run: +```sh +source ~/.bashrc +set_demo +``` + +### Arguments + +Use help to list possible parameters (or read [`kaldi_data_generator/arguments.py`](kaldi_data_generator/arguments.py)) ```bash kaldi-data-generator --help ``` -There is also an option that skips all vertical transcriptions and it is `--skip_vertical_lines` -#### Kaldi format -Simple example: +You can also set the arguments using a JSON or YAML configuration file: +```yaml +--- +format: kaldi +dataset_name: balsac +out_dir: my_balsac_kaldi +common: + cache_dir: "/tmp/kaldi_data_generator_solene/cache/" + log_parameters: true +image: + extraction_mode: deskew_min_area_rect + max_deskew_angle: 45 +split: + train_ratio: 0.8 + test_ratio: 0.1 +select: + pages: + - 18c1d2d9-72e8-4f7a-a866-78b59dd407dd + - 901b9c27-1cbe-44ea-94a0-d9c783f17905 + - db9dd27c-e96c-43c2-bf29-991212243453 + - b87999e2-3733-43b1-b8ef-0a297f90bf0f + - 7fe3d786-068f-48c9-ae63-86db2f986c4c + - 4fc61e75-4a11-42e3-b317-348451629bda + - 3e7e37c2-d0cc-41b3-8d8c-6de0bbc69012 + - 63b6e80b-a825-4068-a12a-d12e3edf5f80 + - b11decff-1c07-4c51-a5be-401974ea55ea + - 735cdde6-e540-4dbd-b271-2206e2498156 +filter: + transcription_type: text_line +``` +In this case, run: +```sh +kaldi-data-generator --config config.yaml +``` + +You can also set manually some parameters in the command line. +```sh +kaldi-data-generator --config config.yaml -o my_balsac_kraken -f kraken --image.extraction_mode polygon +``` + +Every run will export a `config.yaml` file and a `param.json` that can be used to reproduce the data generation. + +## Examples + +> :pencil: these corpus ids are from https://demo.arkindex.org/, use `set_demo` + +Two formats are currently supported: `kaldi` and `kraken`. 
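+
+Any of the runs below can also be replayed from an exported configuration. For instance, assuming a previous run used the configuration example above (so its output directory is `my_balsac_kaldi`), it should be reproducible with:
+```sh
+kaldi-data-generator --config my_balsac_kaldi/config.yaml
+```
+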
+### Kaldi format
+
+#### With page ids
 ```bash
-kaldi-data-generator -f kaldi --dataset_name my_balsac --out_dir /tmp/balsac/ --volumes 8f4005e9-1921-47b0-be7b-e27c7fd29486 d2f7c563-1622-4721-bd51-96fab97189f7
+kaldi-data-generator -f kaldi --dataset_name my_balsac --out_dir balsac --select.pages [18c1d2d9-72e8-4f7a-a866-78b59dd407dd,901b9c27-1cbe-44ea-94a0-d9c783f17905,db9dd27c-e96c-43c2-bf29-991212243453]
 ```
-With corpus ids
+#### With volume ids
 ```bash
-kaldi-data-generator -f kaldi --dataset_name cz --out_dir /tmp/home_cz/ --corpora 1ed45e94-9108-4029-a529-9abe37f55ba0
+kaldi-data-generator -f kaldi --dataset_name my_balsac --out_dir balsac --select.volumes [1d5a26d8-6a3e-45ed-bbb6-5a33d09782aa,46a3426f-86d4-45f1-bd57-0de43cd63efd,85207944-2230-4b76-a98f-735a11506743]
 ```
-Polygon example:
+#### With corpus ids
 ```bash
-kaldi-data-generator -f kaldi --dataset_name my_balsac2 --extraction_mode polygon --out_dir /tmp/balsac/ --pages 50e1c3c0-2fe9-4216-805e-1a2fd2e7e9f4
+kaldi-data-generator -f kaldi --dataset_name my_balsac --out_dir balsac --select.corpora [135eb31f-2c33-4ae3-be4e-2ae9adfd7c75] --select.volume_type page
 ```
 
 The script creates 3 directories `Lines`, `Transcriptions`, `Partitions` in the specified `out_dir`.
@@ -43,15 +104,14 @@ The contents of these directories must be copied (or symlinked) to the correspon
 
 ### Kraken format
 
-simple examples:
-```
-$ kaldi-data-generator -f kraken -o <output_dir> --volumes <volume_id> --no_split
-```
-For instance to download the 4 sets from IAM (2 validation set on Arkindex) in 3 directories :
+Create a kraken database based on existing splits:
 ```
-$ kaldi-data-generator -f kraken -o iam_training --volumes e7a95479-e5fc-4b20-830c-0c6e38bf8f72 --no_split
-$ kaldi-data-generator -f kraken -o iam_validation --volumes edc78ee1-09e0-4671-806b-5fc0392707d9 --no_split
-$ kaldi-data-generator -f kraken -o iam_validation --volumes fefbbfca-a6dd-4e00-8797-0d4628cb024d --no_split
-$ kaldi-data-generator -f kraken -o iam_test --volumes 0ce2b631-01d7-49bf-b213-ceb6eae74a9b --no_split
+kaldi-data-generator -f kraken -o balsac_training --select.volumes [e7a95479-e5fc-4b20-830c-0c6e38bf8f72] --split.no_split
+kaldi-data-generator -f kraken -o balsac_validation --select.volumes [db9dd27c-e96c-43c2-bf29-991212243453] --split.no_split
+kaldi-data-generator -f kraken -o balsac_test --select.volumes [901b9c27-1cbe-44ea-94a0-d9c783f17905] --split.no_split
 ```
+
+## TODO
+* Pylaia format
+* DAN format
+* Resize image (fixed height, fixed_width, rescale...)
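+
+## Output layout
+For reference, a `kaldi` run with the configuration example above is expected to leave something like the following in its output directory (names below are illustrative; the `param-*.json` name includes the actual dataset name, format and date):
+```
+my_balsac_kaldi/
+    Lines/
+    Transcriptions/
+    Partitions/
+        TrainLines.lst
+        ValidationLines.lst
+        TestLines.lst
+    config.yaml
+    param-balsac-kaldi-<date>.json
+```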
\ No newline at end of file diff --git a/kaldi_data_generator/arguments.py b/kaldi_data_generator/arguments.py new file mode 100644 index 0000000..9c1d8a6 --- /dev/null +++ b/kaldi_data_generator/arguments.py @@ -0,0 +1,146 @@ +# -*- coding: utf-8 -*- +import getpass +from dataclasses import dataclass, field +from enum import Enum +from typing import List, Optional + +USER = getpass.getuser() + + +class Style(Enum): + handwritten: str = "handwritten" + typewritten: str = "typewritten" + other: str = "other" + + +class ExtractionMode(Enum): + boundingRect: str = "boundingRect" + min_area_rect: str = "min_area_rect" + deskew_min_area_rect: str = "deskew_min_area_rect" + skew_min_area_rect: str = "skew_min_area_rect" + polygon: str = "polygon" + skew_polygon: str = "skew_polygon" + deskew_polygon: str = "deskew_polygon" + + +class TranscriptionType(Enum): + word: str = "word" + text_line: str = "text_line" + half_subject_line: str = "half_subject_line" + text_zone: str = "text_zone" + paragraph: str = "paragraph" + act: str = "act" + page: str = "page" + + +@dataclass +class SelectArgs: + """ + Arguments to select elements from Arkindex + + Args: + corpora (list): List of corpus ids to be used. + volumes (list): List of volume ids to be used. + folders (list): List of folder ids to be used. Elements of `volume_type` will be searched recursively in these folders + pages (list): List of page ids to be used. + selection (bool): Get elements from selection + volume_type (str): Volumes (1 level above page) may have a different name on corpora + """ + + corpora: Optional[List[str]] = field(default_factory=list) + volumes: Optional[List[str]] = field(default_factory=list) + folders: Optional[List[str]] = field(default_factory=list) + pages: Optional[List[str]] = field(default_factory=list) + selection: bool = False + volume_type: str = "volume" + + +@dataclass +class CommonArgs: + """ + General arguments + + Args: + cache_dir (str): Cache directory where to save the full size downloaded images. + log_parameters (bool): Save every parameters to a JSON file. + """ + + cache_dir: str = f"/tmp/kaldi_data_generator_{USER}/cache/" + log_parameters: bool = True + + +@dataclass +class SplitArgs: + """ + Arguments related to data splitting into training, validation and test subsets. + + Args: + train_ratio (float): Ratio of data to be used in the training set. Should be between 0 and 1. + test_ratio (float): Ratio of data to be used in the testing set. Should be between 0 and 1. + val_ratio (float): Ratio of data to be used in the validation set. The sum of three variables should equal 1. + use_existing_split (bool): Use an existing split instead of random. Expecting line_ids to be prefixed with (train, val and test). + split_only (bool): Create the split from already downloaded lines, don't download the lines + no_split (bool): No splitting of the data to be done just download the line in the right format + """ + + train_ratio: float = 0.8 + test_ratio: float = 0.1 + val_ratio: float = 1 - train_ratio - test_ratio + use_existing_split: bool = False + split_only: bool = False + no_split: bool = False + + +@dataclass +class ImageArgs: + """ + Arguments related to image transformation. + + Args: + extraction_mode: Mode for extracting the line images: {[e.name for e in Extraction]}, + max_deskew_angle: Maximum angle by which deskewing is allowed to rotate the line image. + the angle determined by deskew tool is bigger than max then that line won't be deskewed/rotated. 
+        skew_angle: Angle by which the line image will be rotated. Useful for data augmentation -
+            creating skewed text lines for a more robust model. Only used with skew_* extraction modes.
+        should_rotate (bool): Use text line rotation class to rotate lines if possible
+        grayscale (bool): Convert images to grayscale (By default grayscale)
+        scale_x (float): Ratio of how much to scale the polygon horizontally (1.0 means no rescaling)
+        scale_y_top (float): Ratio of how much to scale the polygon vertically on the top (1.0 means no rescaling)
+        scale_y_bottom (float): Ratio of how much to scale the polygon vertically on the bottom (1.0 means no rescaling)
+    """
+
+    extraction_mode: ExtractionMode = ExtractionMode.deskew_min_area_rect
+    max_deskew_angle: int = 45
+    skew_angle: int = 0
+    should_rotate: bool = False
+    grayscale: bool = True
+    scale_x: Optional[float] = None
+    scale_y_top: Optional[float] = None
+    scale_y_bottom: Optional[float] = None
+
+
+@dataclass
+class FilterArgs:
+    """
+    Arguments related to element filtering.
+
+    Args:
+        transcription_type: Which type of elements' transcriptions to use? (page, paragraph, text_line, etc)
+        ignored_classes: List of ignored ml_class names. Filter lines by class
+        accepted_classes: List of accepted ml_class names. Filter lines by class
+        accepted_worker_version_ids: List of accepted worker version ids. Filter transcriptions by worker version ids.
+            The order is important - only up to one transcription will be chosen per element (text_line)
+            and the worker version order defines the precedence. If there exists a transcription for
+            the first worker version then it will be chosen, otherwise will continue on to the next
+            worker version.
+            Use `--accepted_worker_version_ids manual` to get only manual transcriptions
+        style: Filter line images by style class.
'other' corresponds to line elements that have neither + handwritten or typewritten class : {[s.name for s in Style]} + """ + + transcription_type: TranscriptionType = TranscriptionType.text_line + ignored_classes: List[str] = field(default_factory=list) + accepted_classes: List[str] = field(default_factory=list) + accepted_worker_version_ids: List[str] = field(default_factory=list) + skip_vertical_lines: bool = False + style: Style = None diff --git a/kaldi_data_generator/main.py b/kaldi_data_generator/main.py index 3ae9200..538e720 100644 --- a/kaldi_data_generator/main.py +++ b/kaldi_data_generator/main.py @@ -1,8 +1,6 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- - import argparse -import getpass import os import random from collections import Counter, defaultdict @@ -18,10 +16,20 @@ from arkindex import options_from_env from line_image_extractor.extractor import extract, read_img, save_img from line_image_extractor.image_utils import WHITE, Extraction, rotate_and_trim +import jsonargparse +from kaldi_data_generator.arguments import ( + CommonArgs, + FilterArgs, + ImageArgs, + SelectArgs, + SplitArgs, + Style, +) from kaldi_data_generator.image_utils import download_image, resize_transcription_data from kaldi_data_generator.utils import ( CachedApiClient, TranscriptionData, + export_parameters, logger, write_file, ) @@ -31,6 +39,7 @@ random.seed(SEED) MANUAL = "manual" TEXT_LINE = "text_line" DEFAULT_RESCALE = 1.0 +STYLE_CLASSES = [el.value for el in Style] ROTATION_CLASSES_TO_ANGLES = { "rotate_0": 0, @@ -46,67 +55,46 @@ def create_api_client(cache_dir=None): return CachedApiClient(cache_root=cache_dir, **options_from_env()) -class Style(Enum): - handwritten: str = "handwritten" - typewritten: str = "typewritten" - other: str = "other" - - -STYLE_CLASSES = [s.name for s in [Style.handwritten, Style.typewritten]] - - class HTRDataGenerator: def __init__( self, format, - dataset_name="foo", - out_dir_base="/tmp/kaldi_data", - grayscale=True, - extraction=Extraction.boundingRect, - accepted_classes=None, - ignored_classes=None, - style=None, - skip_vertical_lines=False, - accepted_worker_version_ids=None, - transcription_type=TEXT_LINE, - max_deskew_angle=45, - skew_angle=0, - should_rotate=False, - scale_x=None, - scale_y_top=None, - scale_y_bottom=None, - cache_dir=None, + dataset_name="my_dataset", + out_dir_base="data", + image=ImageArgs(), + common=CommonArgs(), + filter=FilterArgs(), api_client=None, ): self.format = format self.out_dir_base = out_dir_base self.dataset_name = dataset_name - self.grayscale = grayscale - self.extraction_mode = extraction - self.accepted_classes = accepted_classes - self.ignored_classes = ignored_classes + self.grayscale = image.grayscale + self.extraction_mode = Extraction[image.extraction_mode.value] + self.accepted_classes = filter.accepted_classes + self.ignored_classes = filter.ignored_classes self.should_filter_by_class = bool(self.accepted_classes) or bool( self.ignored_classes ) - self.accepted_worker_version_ids = accepted_worker_version_ids + self.accepted_worker_version_ids = filter.accepted_worker_version_ids self.should_filter_by_worker = bool(self.accepted_worker_version_ids) - self.style = style + self.style = filter.style self.should_filter_by_style = bool(self.style) - self.transcription_type = transcription_type - self.skip_vertical_lines = skip_vertical_lines + self.transcription_type = filter.transcription_type.value + self.skip_vertical_lines = filter.skip_vertical_lines self.skipped_pages_count = 0 
self.skipped_vertical_lines_count = 0 self.accepted_lines_count = 0 - self.max_deskew_angle = max_deskew_angle - self.skew_angle = skew_angle - self.should_rotate = should_rotate - if scale_x or scale_y_top or scale_y_bottom: + self.max_deskew_angle = image.max_deskew_angle + self.skew_angle = image.skew_angle + self.should_rotate = image.should_rotate + if image.scale_x or image.scale_y_top or image.scale_y_bottom: self.should_resize_polygons = True # use 1.0 as default - no resize, if not specified - self.scale_x = scale_x or DEFAULT_RESCALE - self.scale_y_top = scale_y_top or DEFAULT_RESCALE - self.scale_y_bottom = scale_y_bottom or DEFAULT_RESCALE + self.scale_x = image.scale_x or DEFAULT_RESCALE + self.scale_y_top = image.scale_y_top or DEFAULT_RESCALE + self.scale_y_bottom = image.scale_y_bottom or DEFAULT_RESCALE else: self.should_resize_polygons = False self.api_client = api_client @@ -129,7 +117,7 @@ class HTRDataGenerator: ) os.makedirs(self.out_line_img_dir, exist_ok=True) - self.cache_dir = cache_dir + self.cache_dir = Path(common.cache_dir) logger.info(f"Setting up cache to {self.cache_dir}") self.img_cache_dir = self.cache_dir / "images" self.img_cache_dir.mkdir(exist_ok=True, parents=True) @@ -486,17 +474,24 @@ class HTRDataGenerator: trans.text, ) - def run_selection(self): + def run_selection(self, select): + """ + Update select to keep track of selected ids. + """ selected_elems = [e for e in self.api_client.paginate("ListSelection")] for elem_type, elems_of_type in groupby( selected_elems, key=lambda x: x["type"] ): + elements_ids = [el["id"] for el in elems_of_type] if elem_type == "page": - self.run_pages(list(elems_of_type)) - elif elem_type in ["volume", "folder"]: - self.run_volumes(list(elems_of_type)) + select.pages += elements_ids + elif elem_type == "volume": + select.volumes += elements_ids + elif elem_type == "folder": + select.folders += elements_ids else: raise ValueError(f"Unsupported element type {elem_type} in selection!") + return select def run_pages(self, pages: list): if all(isinstance(n, str) for n in pages): @@ -548,15 +543,9 @@ class HTRDataGenerator: class Split(Enum): - Train: int = 0 - Test: int = 1 - Validation: int = 2 - - @property - def short_name(self) -> str: - if self == self.Validation: - return "val" - return self.name.lower() + Train: str = "train" + Test: str = "test" + Validation: str = "val" class KaldiPartitionSplitter: @@ -564,44 +553,56 @@ class KaldiPartitionSplitter: self, out_dir_base="/tmp/kaldi_data", split_train_ratio=0.8, + split_val_ratio=0.1, split_test_ratio=0.1, use_existing_split=False, ): self.out_dir_base = out_dir_base self.split_train_ratio = split_train_ratio self.split_test_ratio = split_test_ratio - self.split_val_ratio = 1 - self.split_train_ratio - self.split_test_ratio + self.split_val_ratio = split_val_ratio self.use_existing_split = use_existing_split def page_level_split(self, line_ids: list) -> dict: - # need to sort again, because `set` will lose the order - page_ids = sorted({"_".join(line_id.split("_")[:-2]) for line_id in line_ids}) + """ + Split pages into train, validation and test subsets. + Don't split lines to avoid data leakage. 
+ line_ids (list): a list of line ids named {page_id}_{line_number}_{line_id} + """ + # Get page ids from line ids to create splits at page level + page_ids = ["_".join(line_id.split("_")[:-2]) for line_id in line_ids] + # Remove duplicates and sort for reproducibility + page_ids = sorted(set(page_ids)) random.Random(SEED).shuffle(page_ids) page_count = len(page_ids) - train_page_ids = page_ids[: round(page_count * self.split_train_ratio)] - page_ids = page_ids[round(page_count * self.split_train_ratio) :] - - test_page_ids = page_ids[: round(page_count * self.split_test_ratio)] - page_ids = page_ids[round(page_count * self.split_test_ratio) :] - - val_page_ids = page_ids + # Use np.split to split in three sets + stop_train_idx = round(page_count * self.split_train_ratio) + stop_val_idx = stop_train_idx + round(page_count * self.split_val_ratio) + train_page_ids, val_page_ids, test_page_ids = np.split( + page_ids, [stop_train_idx, stop_val_idx] + ) + # Build dictionary that will be used to split lines {id: split} page_dict = {page_id: Split.Train.value for page_id in train_page_ids} - page_dict.update({page_id: Split.Test.value for page_id in test_page_ids}) page_dict.update({page_id: Split.Validation.value for page_id in val_page_ids}) + page_dict.update({page_id: Split.Test.value for page_id in test_page_ids}) return page_dict def existing_split(self, line_ids: list) -> list: - split_dict = {split.short_name: [] for split in Split} + """ + Expect line_ids to be named {split}/{path_to_image} where split in ["train", "val", "test"] + """ + split_dict = {split: [] for split in Split} for line_id in line_ids: split_prefix = line_id.split("/")[0].lower() split_dict[split_prefix].append(line_id) - splits = [split_dict[split.short_name] for split in Split] - return splits + return split_dict def create_partitions(self): - logger.info("Creating partitions") + """ """ + logger.info(f"Creating {[split.value for split in Split]} partitions") + # Get all images ids (and remove extension) lines_path = Path(f"{self.out_dir_base}/Lines") line_ids = [ str(file.relative_to(lines_path).with_suffix("")) @@ -613,7 +614,8 @@ class KaldiPartitionSplitter: datasets = self.existing_split(line_ids) else: page_dict = self.page_level_split(line_ids) - datasets = [[] for _ in range(3)] + # extend this split for lines + datasets = {s.value: [] for s in Split} for line_id in line_ids: page_id = "_".join(line_id.split("_")[:-2]) split_id = page_dict[page_id] @@ -621,315 +623,157 @@ class KaldiPartitionSplitter: partitions_dir = os.path.join(self.out_dir_base, "Partitions") os.makedirs(partitions_dir, exist_ok=True) - for i, dataset in enumerate(datasets): - if not dataset: - logger.info(f"Partition {Split(i).name} is empty! Skipping..") + for split, split_line_ids in datasets.items(): + if not split_line_ids: + logger.info(f"Partition {split} is empty! 
Skipping...") continue - file_name = f"{partitions_dir}/{Split(i).name}Lines.lst" - write_file(file_name, "\n".join(dataset) + "\n") + file_name = f"{partitions_dir}/{Split(split).name}Lines.lst" + write_file(file_name, "\n".join(split_line_ids) + "\n") + return datasets -def create_parser(): - user_name = getpass.getuser() +def run(format, dataset_name, out_dir, common, image, split, select, filter): - parser = argparse.ArgumentParser( - description="Script to generate Kaldi or kraken training data from annotations from Arkindex", - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - parser.add_argument( - "-f", - "--format", - type=str, - help="is the data generated going to be used for kaldi or kraken", - ) - parser.add_argument( - "-n", - "--dataset_name", - type=str, - help="Name of the dataset being created for kaldi or kraken " - "(useful for distinguishing different datasets when in Lines or Transcriptions directory)", - ) - parser.add_argument( - "-o", "--out_dir", type=str, required=True, help="output directory" - ) - parser.add_argument( - "--train_ratio", - type=float, - default=0.8, - help="Ratio of pages to be used in train (between 0 and 1)", - ) - parser.add_argument( - "--test_ratio", - type=float, - default=0.1, - help="Ratio of pages to be used in test (between 0 and 1 - train_ratio)", - ) - parser.add_argument( - "--use_existing_split", - action="store_true", - default=False, - help="Use an existing split instead of random. " - "Expecting line_ids to be prefixed with (train, val and test)", - ) - parser.add_argument( - "--split_only", - "--no_download", - action="store_true", - default=False, - help="Create the split from already downloaded lines, don't download the lines", - ) - parser.add_argument( - "--no_split", - action="store_true", - default=False, - help="No splitting of the data to be done just download the line in the right format", - ) + api_client = create_api_client(Path(common.cache_dir)) - parser.add_argument( - "-e", - "--extraction_mode", - type=lambda x: Extraction[x], - default=Extraction.boundingRect, - help=f"Mode for extracting the line images: {[e.name for e in Extraction]}", - ) - - parser.add_argument( - "--max_deskew_angle", - type=int, - default=45, - help="Maximum angle by which deskewing is allowed to rotate the line image. " - "If the angle determined by deskew tool is bigger than max " - "then that line won't be deskewed/rotated.", - ) - - parser.add_argument( - "--skew_angle", - type=int, - default=0, - help="Angle by which the line image will be rotated. Useful for data augmnetation" - " - creating skewed text lines for a more robust model." 
- " Only used with skew_* extraction modes.", - ) + if not split.split_only: + data_generator = HTRDataGenerator( + format=format, + dataset_name=dataset_name, + out_dir_base=out_dir, + common=common, + image=image, + filter=filter, + api_client=api_client, + ) - parser.add_argument( - "--should_rotate", - action="store_true", - default=False, - help="Use text line rotation class to rotate lines if possible", - ) + # extract all the lines and transcriptions + if select.selection: + select = data_generator.run_selection(select) + if select.pages: + data_generator.run_pages(select.pages) + if select.volumes: + data_generator.run_volumes(select.volumes) + if select.folders: + data_generator.run_folders(select.folders, select.volume_type) + if select.corpora: + data_generator.run_corpora(select.corpora, select.volume_type) + if data_generator.skipped_vertical_lines_count > 0: + logger.info( + f"Number of skipped pages: {data_generator.skipped_pages_count}" + ) + _skipped_vertical_count = data_generator.skipped_vertical_lines_count + _total_count = _skipped_vertical_count + data_generator.accepted_lines_count + skipped_ratio = _skipped_vertical_count / _total_count * 100 - parser.add_argument( - "--transcription_type", - type=str, - default="text_line", - help="Which type of elements' transcriptions to use? (page, paragraph, text_line, etc)", - ) + logger.info( + f"Skipped {data_generator.skipped_vertical_lines_count} vertical lines ({round(skipped_ratio, 2)}%)" + ) + else: + logger.info("Creating a split from already downloaded files") + data_generator = None - group = parser.add_mutually_exclusive_group(required=False) - group.add_argument( - "--grayscale", - action="store_true", - dest="grayscale", - help="Convert images to grayscale (By default grayscale)", - ) - group.add_argument( - "--color", action="store_false", dest="grayscale", help="Use color images" - ) - group.set_defaults(grayscale=True) + if not split.no_split: + kaldi_partitioner = KaldiPartitionSplitter( + out_dir_base=out_dir, + split_train_ratio=split.train_ratio, + split_val_ratio=split.val_ratio, + split_test_ratio=split.test_ratio, + use_existing_split=split.use_existing_split, + ) - parser.add_argument( - "--corpora", - nargs="*", - help="List of corpus ids to be used, separated by spaces", - ) - parser.add_argument( - "--folders", - type=str, - nargs="*", - help="List of folder ids to be used, separated by spaces. " - "Elements of `volume_type` will be searched recursively in these folders", - ) - parser.add_argument( - "--volumes", - nargs="*", - help="List of volume ids to be used, separated by spaces", - ) - parser.add_argument( - "--pages", nargs="*", help="List of page ids to be used, separated by spaces" - ) - parser.add_argument( - "-v", - "--volume_type", - type=str, - default="volume", - help="Volumes (1 level above page) may have a different name on corpora", - ) - parser.add_argument( - "--selection", - action="store_true", - default=False, - help="Get elements from selection", - ) + # create partitions from all the extracted data + datasets = kaldi_partitioner.create_partitions() + else: + logger.info("No split to be done") + datasets = {} - parser.add_argument( - "--skip_vertical_lines", - action="store_true", - default=False, - help="skips vertical lines when downloading", - ) + logger.info("DONE") - parser.add_argument( - "--ignored_classes", - nargs="*", - default=[], - help="List of ignored ml_class names. 
Filter lines by class", + export_parameters( + format, + dataset_name, + out_dir, + common, + image, + split, + select, + filter, + datasets, + arkindex_api_url=options_from_env()["base_url"], ) - parser.add_argument( - "--accepted_classes", - nargs="*", - default=[], - help="List of accepted ml_class names. Filter lines by class", + logger.warning( + f"Consider cleaning your cache directory {common.cache_dir} if you are done." ) - parser.add_argument( - "--accepted_worker_version_ids", - nargs="*", - default=[], - help="List of accepted worker version ids. Filter transcriptions by worker version ids." - "The order is important - only up to one transcription will be chosen per element (text_line)" - " and the worker version order defines the precedence. If there exists a transcription for" - " the first worker version then it will be chosen, otherwise will continue on to the next" - " worker version." - " Use `--accepted_worker_version_ids manual` to get only manual transcriptions", - ) - parser.add_argument( - "--style", - type=lambda x: Style[x.lower()], - default=None, - help=f"Filter line images by style class. 'other' corresponds to line elements that " - f"have neither handwritten or typewritten class : {[s.name for s in Style]}", +def get_args(): + parser = jsonargparse.ArgumentParser( + description="Script to generate Kaldi or kraken training data from annotations from Arkindex", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + parse_as_dict=True, ) - parser.add_argument( - "--scale_x", - type=float, - default=None, - help="Ratio of how much to scale the polygon horizontally (1.0 means no rescaling)", + "--config", action=jsonargparse.ActionConfigFile, help="Configuration file" ) parser.add_argument( - "--scale_y_top", - type=float, - default=None, - help="Ratio of how much to scale the polygon vertically on the top (1.0 means no rescaling)", + "-f", + "--format", + type=str, + required=True, + help="is the data generated going to be used for kaldi or kraken", ) - parser.add_argument( - "--scale_y_bottom", - type=float, - default=None, - help="Ratio of how much to scale the polygon vertically on the bottom (1.0 means no rescaling)", + "-d", + "--dataset_name", + type=str, + required=True, + help="Name of the dataset being created for kaldi or kraken " + "(useful for distinguishing different datasets when in Lines or Transcriptions directory)", ) - parser.add_argument( - "--cache_dir", - type=Path, - default=Path(f"/tmp/kaldi_data_generator_{user_name}/cache/"), - help="Cache dir where to save the full size downloaded images. 
Change it to force redownload.", + "-o", "--out_dir", type=str, required=True, help="output directory" ) - return parser - - -def main(): - parser = create_parser() - args = parser.parse_args() - - if not args.dataset_name and not args.split_only and not args.format == "kraken": - parser.error("--dataset_name must be specified (unless --split-only)") - - if args.accepted_classes and args.ignored_classes: - if set(args.accepted_classes) & set(args.ignored_classes): + parser.add_class_arguments(CommonArgs, "common") + parser.add_class_arguments(ImageArgs, "image") + parser.add_class_arguments(SplitArgs, "split") + parser.add_class_arguments(FilterArgs, "filter") + parser.add_class_arguments(SelectArgs, "select") + + args = parser.parse_args(with_meta=False) + args["common"] = CommonArgs(**args["common"]) + args["image"] = ImageArgs(**args["image"]) + args["split"] = SplitArgs(**args["split"]) + args["select"] = SelectArgs(**args["select"]) + args["filter"] = FilterArgs(**args["filter"]) + + # Check overlap of accepted and ignored classes + accepted_classes = args["filter"].accepted_classes + ignored_classes = args["filter"].accepted_classes + if accepted_classes and ignored_classes: + if set(accepted_classes) & set(ignored_classes): parser.error( - f"--accepted_classes and --ignored_classes values must not overlap ({args.accepted_classes} - {args.ignored_classes})" + f"--filter.accepted_classes and --filter.ignored_classes values must not overlap ({accepted_classes} - {ignored_classes})" ) - if args.style and (args.accepted_classes or args.ignored_classes): - if set(STYLE_CLASSES) & ( - set(args.accepted_classes) | set(args.ignored_classes) - ): + if args["filter"].style and (accepted_classes or ignored_classes): + if set(STYLE_CLASSES) & (set(accepted_classes) | set(ignored_classes)): parser.error( f"--style class values ({STYLE_CLASSES}) shouldn't be in the accepted_classes list " f"(or ignored_classes list) " - "if both --style and --accepted_classes (or --ignored_classes) are used together." - ) - - logger.info(f"ARGS {args} \n") - - api_client = create_api_client(args.cache_dir) - - if not args.split_only: - data_generator = HTRDataGenerator( - format=args.format, - dataset_name=args.dataset_name, - out_dir_base=args.out_dir, - grayscale=args.grayscale, - extraction=args.extraction_mode, - accepted_classes=args.accepted_classes, - ignored_classes=args.ignored_classes, - style=args.style, - skip_vertical_lines=args.skip_vertical_lines, - transcription_type=args.transcription_type, - accepted_worker_version_ids=args.accepted_worker_version_ids, - max_deskew_angle=args.max_deskew_angle, - skew_angle=args.skew_angle, - should_rotate=args.should_rotate, - scale_x=args.scale_x, - scale_y_top=args.scale_y_top, - scale_y_bottom=args.scale_y_bottom, - cache_dir=args.cache_dir, - api_client=api_client, - ) - - # extract all the lines and transcriptions - if args.selection: - data_generator.run_selection() - if args.pages: - data_generator.run_pages(args.pages) - if args.volumes: - data_generator.run_volumes(args.volumes) - if args.folders: - data_generator.run_folders(args.folders, args.volume_type) - if args.corpora: - data_generator.run_corpora(args.corpora, args.volume_type) - if data_generator.skipped_vertical_lines_count > 0: - logger.info( - f"Number of skipped pages: {data_generator.skipped_pages_count}" + "if --filter.style is used with either --filter.accepted_classes or --filter.ignored_classes." 
) - _skipped_vertical_count = data_generator.skipped_vertical_lines_count - _total_count = _skipped_vertical_count + data_generator.accepted_lines_count - skipped_ratio = _skipped_vertical_count / _total_count * 100 - logger.info( - f"Skipped {data_generator.skipped_vertical_lines_count} vertical lines ({round(skipped_ratio, 2)}%)" - ) - else: - logger.info("Creating a split from already downloaded files") - if not args.no_split: - kaldi_partitioner = KaldiPartitionSplitter( - out_dir_base=args.out_dir, - split_train_ratio=args.train_ratio, - split_test_ratio=args.test_ratio, - use_existing_split=args.use_existing_split, - ) + del args["config"] + return args - # create partitions from all the extracted data - kaldi_partitioner.create_partitions() - else: - logger.info("No split to be done") - logger.info("DONE") +def main(): + args = get_args() + logger.info(f"Arguments: {args} \n") + run(**args) if __name__ == "__main__": diff --git a/kaldi_data_generator/utils.py b/kaldi_data_generator/utils.py index 0aed87a..59e0bbe 100644 --- a/kaldi_data_generator/utils.py +++ b/kaldi_data_generator/utils.py @@ -1,8 +1,11 @@ # -*- coding: utf-8 -*- +import getpass import json import logging import os +import socket import sys +from datetime import datetime from pathlib import Path import cv2 @@ -10,6 +13,8 @@ import numpy as np from arkindex import ArkindexClient from line_image_extractor.image_utils import BoundingBox +import yaml + logging.basicConfig( level=logging.INFO, format="%(asctime)s %(levelname)s/%(name)s: %(message)s" ) @@ -132,3 +137,74 @@ class CachedApiClient(ArkindexClient): request_cache.parent.mkdir(parents=True, exist_ok=True) request_cache.write_text(json.dumps(results)) logger.info("Saved") + + +def write_json(d, filename): + with open(filename, "w") as f: + f.write(json.dumps(d, indent=4)) + + +def write_yaml(d, filename): + with open(filename, "w") as f: + yaml.dump(d, f) + + +def export_parameters( + format, + dataset_name, + out_dir, + common, + image, + split, + select, + filter, + datasets, + arkindex_api_url, +): + """ + Dump a JSON log file to keep track of parameters for this dataset + """ + # Get config dict + config = { + "format": format, + "dataset_name": dataset_name, + "out_dir": out_dir, + "common": vars(common), + "image": vars(image), + "split": vars(split), + "select": vars(select), + "filter": vars(filter), + } + # Serialize special classes + image.extraction_mode = image.extraction_mode.name + filter.transcription_type = filter.transcription_type.value + common.cache_dir = str(common.cache_dir) + + if common.log_parameters: + # Get additional info on dataset and user + current_datetime = datetime.now().strftime("%Y-%m-%d|%H-%M-%S") + + parameters = {} + parameters["config"] = config + parameters["dataset"] = ( + datasets + if isinstance(datasets, dict) + else {"train": datasets[0], "valid": datasets[1], "test": datasets[2]} + ) + parameters["info"] = { + "user": getpass.getuser(), + "date": current_datetime, + "device": socket.gethostname(), + "arkindex_api_url": arkindex_api_url, + } + + # Export + parameter_file = os.path.join( + out_dir, f"param-{dataset_name}-{format}-{current_datetime}.json" + ) + write_json(parameters, parameter_file) + logger.info(f"Parameters exported in file {parameter_file}") + + config_file = os.path.join(out_dir, "config.yaml") + write_yaml(config, config_file) + logger.info(f"Config exported in file {config_file}") diff --git a/requirements.txt b/requirements.txt index f46a80b..cca4484 100644 --- a/requirements.txt +++ 
b/requirements.txt @@ -1,6 +1,6 @@ apistar==0.7.2 arkindex-client==1.0.9 +jsonargparse teklia-line-image-extractor==0.2.4 tqdm==4.64.0 - typesystem==0.2.5 diff --git a/tests/data/pinned_insects/partitions/TestLines.lst b/tests/data/pinned_insects/partitions/TestLines.lst index 1ce3c45..4de4011 100644 --- a/tests/data/pinned_insects/partitions/TestLines.lst +++ b/tests/data/pinned_insects/partitions/TestLines.lst @@ -1,6 +1,2 @@ -testing/5bcb0a49-4810-4107-9b7f-62c30f913399_005_74e593f4-e5d1-4891-ba99-e9eabc1c9a9b -testing/5bcb0a49-4810-4107-9b7f-62c30f913399_006_51ae11a8-3fb5-47f7-80d6-48b2b235076b -testing/5bcb0a49-4810-4107-9b7f-62c30f913399_007_8d35675d-9f43-4b69-9ad7-e319dfb3be6e -testing/5bcb0a49-4810-4107-9b7f-62c30f913399_008_b4b16c14-ce22-4f13-88f8-7c8e4e32071f -testing/5bcb0a49-4810-4107-9b7f-62c30f913399_009_4ab037a2-451b-40b9-b598-1d14984dd8fd -testing/5bcb0a49-4810-4107-9b7f-62c30f913399_010_52260fd8-cd63-463e-9f74-3ed3a97d14c5 \ No newline at end of file +testing/6139d79e-be7b-46fd-9cde-20268bfb5114_005_3750213d-2db3-4425-9b6c-0e5c0f45596f +testing/6139d79e-be7b-46fd-9cde-20268bfb5114_006_974c3bd8-d42a-4316-a021-345d6b11379a \ No newline at end of file diff --git a/tests/data/pinned_insects/partitions/ValidationLines.lst b/tests/data/pinned_insects/partitions/ValidationLines.lst index 4de4011..1ce3c45 100644 --- a/tests/data/pinned_insects/partitions/ValidationLines.lst +++ b/tests/data/pinned_insects/partitions/ValidationLines.lst @@ -1,2 +1,6 @@ -testing/6139d79e-be7b-46fd-9cde-20268bfb5114_005_3750213d-2db3-4425-9b6c-0e5c0f45596f -testing/6139d79e-be7b-46fd-9cde-20268bfb5114_006_974c3bd8-d42a-4316-a021-345d6b11379a \ No newline at end of file +testing/5bcb0a49-4810-4107-9b7f-62c30f913399_005_74e593f4-e5d1-4891-ba99-e9eabc1c9a9b +testing/5bcb0a49-4810-4107-9b7f-62c30f913399_006_51ae11a8-3fb5-47f7-80d6-48b2b235076b +testing/5bcb0a49-4810-4107-9b7f-62c30f913399_007_8d35675d-9f43-4b69-9ad7-e319dfb3be6e +testing/5bcb0a49-4810-4107-9b7f-62c30f913399_008_b4b16c14-ce22-4f13-88f8-7c8e4e32071f +testing/5bcb0a49-4810-4107-9b7f-62c30f913399_009_4ab037a2-451b-40b9-b598-1d14984dd8fd +testing/5bcb0a49-4810-4107-9b7f-62c30f913399_010_52260fd8-cd63-463e-9f74-3ed3a97d14c5 \ No newline at end of file diff --git a/tests/test_main.py b/tests/test_main.py index 9273975..853239a 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -4,15 +4,12 @@ from pathlib import Path import pytest +from kaldi_data_generator.arguments import FilterArgs from kaldi_data_generator.main import MANUAL, HTRDataGenerator, KaldiPartitionSplitter def test_init(): - - htr_data_gen = HTRDataGenerator( - format="kaldi", accepted_worker_version_ids=[], cache_dir=Path("/tmp/foo/bar") - ) - + htr_data_gen = HTRDataGenerator(format="kaldi") assert htr_data_gen is not None @@ -38,16 +35,14 @@ def test_run_volumes_with_worker_version( out_dir_base = tmpdir htr_data_gen = HTRDataGenerator( format="kaldi", - accepted_worker_version_ids=worker_version_ids, + filter=FilterArgs(accepted_worker_version_ids=worker_version_ids), out_dir_base=out_dir_base, - cache_dir=Path("/tmp/foo/bar"), ) htr_data_gen.api_client = api_client htr_data_gen.get_image = mocker.MagicMock() # return same fake image for all the pages htr_data_gen.get_image.return_value = fake_image - htr_data_gen.run_volumes([fake_volume_id]) trans_files = list(Path(htr_data_gen.out_line_text_dir).glob("*.txt")) @@ -65,13 +60,12 @@ def test_run_volumes_with_worker_version( def test_create_partitions(fake_expected_partitions, tmpdir): out_dir_base = Path(tmpdir) - 
splitter = KaldiPartitionSplitter( - out_dir_base=out_dir_base, - ) + splitter = KaldiPartitionSplitter(out_dir_base=out_dir_base) all_ids = [] for split_name, expected_split_ids in fake_expected_partitions.items(): all_ids += expected_split_ids + # shuffle to show that splits are reproducible random.shuffle(all_ids) -- GitLab