Skip to content
Snippets Groups Projects
Commit 788aa635 authored by Martin's avatar Martin
Browse files

add argparse

parent 5b8081dd
No related branches found
No related tags found
No related merge requests found
#!/usr/bin/env python3 #!/usr/bin/env python3
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import argparse
import os import os
import random import random
from enum import Enum from enum import Enum
...@@ -44,13 +44,9 @@ def write_file(file_name, content): ...@@ -44,13 +44,9 @@ def write_file(file_name, content):
class KaldiDataGenerator: class KaldiDataGenerator:
def __init__(self, dataset_name='foo', out_dir_base='/tmp/kaldi_data', split_train_ratio=0.8, split_test_ratio=0.1, def __init__(self, dataset_name='foo', out_dir_base='/tmp/kaldi_data', grayscale=True):
grayscale=True):
self.out_dir_base = out_dir_base self.out_dir_base = out_dir_base
self.dataset_name = dataset_name self.dataset_name = dataset_name
self.split_train_ratio = split_train_ratio
self.split_test_ratio = split_test_ratio
self.split_val_ratio = 1 - self.split_train_ratio - self.split_test_ratio
self.grayscale = grayscale self.grayscale = grayscale
self.out_line_text_dir = os.path.join(self.out_dir_base, 'Transcriptions', self.dataset_name) self.out_line_text_dir = os.path.join(self.out_dir_base, 'Transcriptions', self.dataset_name)
...@@ -127,6 +123,7 @@ class KaldiPartitionSplitter: ...@@ -127,6 +123,7 @@ class KaldiPartitionSplitter:
self.out_dir_base = out_dir_base self.out_dir_base = out_dir_base
self.split_train_ratio = split_train_ratio self.split_train_ratio = split_train_ratio
self.split_test_ratio = split_test_ratio self.split_test_ratio = split_test_ratio
self.split_val_ratio = 1 - self.split_train_ratio - self.split_test_ratio
def page_level_split(self, line_ids): def page_level_split(self, line_ids):
page_ids = list({'_'.join(line_id.split('_')[:-1]) for line_id in line_ids}) page_ids = list({'_'.join(line_id.split('_')[:-1]) for line_id in line_ids})
...@@ -161,23 +158,83 @@ class KaldiPartitionSplitter: ...@@ -161,23 +158,83 @@ class KaldiPartitionSplitter:
partitions_dir = os.path.join(self.out_dir_base, 'Partitions') partitions_dir = os.path.join(self.out_dir_base, 'Partitions')
os.makedirs(partitions_dir, exist_ok=True) os.makedirs(partitions_dir, exist_ok=True)
for i, dataset in enumerate(datasets): for i, dataset in enumerate(datasets):
if not dataset:
print(f"Partition {Split(i).name} is empty! Skipping..")
continue
file_name = f"{partitions_dir}/{Split(i).name}Lines.lst" file_name = f"{partitions_dir}/{Split(i).name}Lines.lst"
write_file(file_name, '\n'.join(dataset) + '\n') write_file(file_name, '\n'.join(dataset) + '\n')
example_page_ids = [ def example():
'bf23cc96-f6b2-4182-923e-6c163db37eba', example_page_ids = [
'7c51e648-370e-43b7-9340-3b1a17c13828', 'bf23cc96-f6b2-4182-923e-6c163db37eba',
'56521074-59f4-4173-bfc1-4b1384ff8139', '7c51e648-370e-43b7-9340-3b1a17c13828',
] '56521074-59f4-4173-bfc1-4b1384ff8139',
]
example_volume_ids = [
'8f4005e9-1921-47b0-be7b-e27c7fd29486',
]
kaldi_data_generator = KaldiDataGenerator()
kaldi_partitioner = KaldiPartitionSplitter()
# kaldi_data_generator.run_page(example_page_ids)
kaldi_data_generator.run_volumes(example_volume_ids)
kaldi_partitioner.create_partitions()
def create_parser():
parser = argparse.ArgumentParser(
description="Script to generate Kaldi training data from annotations from Arkindex")
parser.add_argument('-n', '--dataset_name', type=str, required=True,
help='Name of the dataset being created for kaldi')
parser.add_argument('-o', '--out_dir', type=str, required=True,
help='output directory')
parser.add_argument('--train_ratio', type=float, default=0.8,
help='Ratio of pages to be used in train (between 0 and 1)')
parser.add_argument('--test_ratio', type=float, default=0.1,
help='Ratio of pages to be used in train (between 0 and 1 - train_ratio)')
group = parser.add_mutually_exclusive_group(required=False)
group.add_argument('--grayscale', action='store_true',
help='Convert images to grayscale')
group.add_argument('--color', action='store_false',
help='Use color images')
parser.set_defaults(grayscale=True)
parser.add_argument('--volumes', nargs='*',
help='List of volume ids to be used, separated by spaces')
parser.add_argument('--pages', nargs='*',
help='List of page ids to be used, separated by spaces')
return parser
def main():
args = create_parser().parse_args()
print("ARGS", args, '\n')
kaldi_data_generator = KaldiDataGenerator(dataset_name=args.dataset_name,
out_dir_base=args.out_dir,
grayscale=args.grayscale)
kaldi_partitioner = KaldiPartitionSplitter(out_dir_base=args.out_dir,
split_train_ratio=args.train_ratio,
split_test_ratio=args.test_ratio)
# extract all the lines and transcriptions
if args.pages:
kaldi_data_generator.run_page(args.pages)
if args.volumes:
kaldi_data_generator.run_volumes(args.volumes)
print()
# create partitions from all the extracted data
kaldi_partitioner.create_partitions()
example_volume_ids = [ print("DONE")
'8f4005e9-1921-47b0-be7b-e27c7fd29486',
]
kaldi_data_generator = KaldiDataGenerator()
kaldi_partitioner = KaldiPartitionSplitter()
# kaldi_data_generator.run_page(example_page_ids) if __name__ == '__main__':
kaldi_data_generator.run_volumes(example_volume_ids) main()
kaldi_partitioner.create_partitions()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment