From 788aa6354c06538dd289bc908fad9a2e5e977673 Mon Sep 17 00:00:00 2001
From: Martin <maarand@teklia.com>
Date: Tue, 5 Nov 2019 14:38:39 +0100
Subject: [PATCH] add argparse

---
Review notes (placed below the "---" marker so they stay out of the commit
message):
* The --color flag now shares dest='grayscale' with --grayscale. Before, it
  stored False into args.color, which main() never reads, so colour output
  could not be selected from the command line.
* The --test_ratio help text said "train"; it now says "test".
* Line structure and indentation were restored from a whitespace-collapsed
  copy of this patch. Blank-line placement was chosen to satisfy the hunk
  line counts and should be confirmed against the repository.

 kaldi_data_generator.py | 95 ++++++++++++++++++++++++++++++++---------
 1 file changed, 76 insertions(+), 19 deletions(-)

diff --git a/kaldi_data_generator.py b/kaldi_data_generator.py
index c02daf1..09bbf8a 100644
--- a/kaldi_data_generator.py
+++ b/kaldi_data_generator.py
@@ -1,6 +1,6 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
-
+import argparse
 import os
 import random
 from enum import Enum
@@ -44,13 +44,9 @@ def write_file(file_name, content):
 
 
 class KaldiDataGenerator:
-    def __init__(self, dataset_name='foo', out_dir_base='/tmp/kaldi_data', split_train_ratio=0.8, split_test_ratio=0.1,
-                 grayscale=True):
+    def __init__(self, dataset_name='foo', out_dir_base='/tmp/kaldi_data', grayscale=True):
         self.out_dir_base = out_dir_base
         self.dataset_name = dataset_name
-        self.split_train_ratio = split_train_ratio
-        self.split_test_ratio = split_test_ratio
-        self.split_val_ratio = 1 - self.split_train_ratio - self.split_test_ratio
         self.grayscale = grayscale
 
         self.out_line_text_dir = os.path.join(self.out_dir_base, 'Transcriptions', self.dataset_name)
@@ -127,6 +123,7 @@ class KaldiPartitionSplitter:
         self.out_dir_base = out_dir_base
         self.split_train_ratio = split_train_ratio
         self.split_test_ratio = split_test_ratio
+        self.split_val_ratio = 1 - self.split_train_ratio - self.split_test_ratio
 
     def page_level_split(self, line_ids):
         page_ids = list({'_'.join(line_id.split('_')[:-1]) for line_id in line_ids})
@@ -161,23 +158,83 @@ class KaldiPartitionSplitter:
         partitions_dir = os.path.join(self.out_dir_base, 'Partitions')
         os.makedirs(partitions_dir, exist_ok=True)
         for i, dataset in enumerate(datasets):
+            if not dataset:
+                print(f"Partition {Split(i).name} is empty! Skipping..")
+                continue
             file_name = f"{partitions_dir}/{Split(i).name}Lines.lst"
             write_file(file_name, '\n'.join(dataset) + '\n')
 
 
-example_page_ids = [
-    'bf23cc96-f6b2-4182-923e-6c163db37eba',
-    '7c51e648-370e-43b7-9340-3b1a17c13828',
-    '56521074-59f4-4173-bfc1-4b1384ff8139',
-]
+def example():
+    example_page_ids = [
+        'bf23cc96-f6b2-4182-923e-6c163db37eba',
+        '7c51e648-370e-43b7-9340-3b1a17c13828',
+        '56521074-59f4-4173-bfc1-4b1384ff8139',
+    ]
+
+    example_volume_ids = [
+        '8f4005e9-1921-47b0-be7b-e27c7fd29486',
+    ]
+
+    kaldi_data_generator = KaldiDataGenerator()
+    kaldi_partitioner = KaldiPartitionSplitter()
+
+    # kaldi_data_generator.run_page(example_page_ids)
+    kaldi_data_generator.run_volumes(example_volume_ids)
+    kaldi_partitioner.create_partitions()
+
+
+def create_parser():
+    parser = argparse.ArgumentParser(
+        description="Script to generate Kaldi training data from annotations from Arkindex")
+    parser.add_argument('-n', '--dataset_name', type=str, required=True,
+                        help='Name of the dataset being created for kaldi')
+    parser.add_argument('-o', '--out_dir', type=str, required=True,
+                        help='output directory')
+    parser.add_argument('--train_ratio', type=float, default=0.8,
+                        help='Ratio of pages to be used in train (between 0 and 1)')
+    parser.add_argument('--test_ratio', type=float, default=0.1,
+                        help='Ratio of pages to be used in test (between 0 and 1 - train_ratio)')
+
+    group = parser.add_mutually_exclusive_group(required=False)
+    group.add_argument('--grayscale', action='store_true',
+                       help='Convert images to grayscale')
+    group.add_argument('--color', action='store_false', dest='grayscale',
+                       help='Use color images')
+    parser.set_defaults(grayscale=True)
+
+    parser.add_argument('--volumes', nargs='*',
+                        help='List of volume ids to be used, separated by spaces')
+    parser.add_argument('--pages', nargs='*',
+                        help='List of page ids to be used, separated by spaces')
+
+    return parser
+
+
+def main():
+    args = create_parser().parse_args()
+
+    print("ARGS", args, '\n')
+
+    kaldi_data_generator = KaldiDataGenerator(dataset_name=args.dataset_name,
+                                              out_dir_base=args.out_dir,
+                                              grayscale=args.grayscale)
+
+    kaldi_partitioner = KaldiPartitionSplitter(out_dir_base=args.out_dir,
+                                               split_train_ratio=args.train_ratio,
+                                               split_test_ratio=args.test_ratio)
+    # extract all the lines and transcriptions
+    if args.pages:
+        kaldi_data_generator.run_page(args.pages)
+    if args.volumes:
+        kaldi_data_generator.run_volumes(args.volumes)
+
+    print()
+    # create partitions from all the extracted data
+    kaldi_partitioner.create_partitions()
 
-example_volume_ids = [
-    '8f4005e9-1921-47b0-be7b-e27c7fd29486',
-]
+    print("DONE")
 
-kaldi_data_generator = KaldiDataGenerator()
-kaldi_partitioner = KaldiPartitionSplitter()
 
-# kaldi_data_generator.run_page(example_page_ids)
-kaldi_data_generator.run_volumes(example_volume_ids)
-kaldi_partitioner.create_partitions()
+if __name__ == '__main__':
+    main()
-- 
GitLab