Skip to content
Snippets Groups Projects
Commit 788aa635 authored by Martin's avatar Martin
Browse files

add argparse

parent 5b8081dd
No related branches found
No related tags found
No related merge requests found
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import argparse
import os
import random
from enum import Enum
......@@ -44,13 +44,9 @@ def write_file(file_name, content):
class KaldiDataGenerator:
def __init__(self, dataset_name='foo', out_dir_base='/tmp/kaldi_data', split_train_ratio=0.8, split_test_ratio=0.1,
grayscale=True):
def __init__(self, dataset_name='foo', out_dir_base='/tmp/kaldi_data', grayscale=True):
self.out_dir_base = out_dir_base
self.dataset_name = dataset_name
self.split_train_ratio = split_train_ratio
self.split_test_ratio = split_test_ratio
self.split_val_ratio = 1 - self.split_train_ratio - self.split_test_ratio
self.grayscale = grayscale
self.out_line_text_dir = os.path.join(self.out_dir_base, 'Transcriptions', self.dataset_name)
......@@ -127,6 +123,7 @@ class KaldiPartitionSplitter:
self.out_dir_base = out_dir_base
self.split_train_ratio = split_train_ratio
self.split_test_ratio = split_test_ratio
self.split_val_ratio = 1 - self.split_train_ratio - self.split_test_ratio
def page_level_split(self, line_ids):
page_ids = list({'_'.join(line_id.split('_')[:-1]) for line_id in line_ids})
......@@ -161,23 +158,83 @@ class KaldiPartitionSplitter:
partitions_dir = os.path.join(self.out_dir_base, 'Partitions')
os.makedirs(partitions_dir, exist_ok=True)
for i, dataset in enumerate(datasets):
if not dataset:
print(f"Partition {Split(i).name} is empty! Skipping..")
continue
file_name = f"{partitions_dir}/{Split(i).name}Lines.lst"
write_file(file_name, '\n'.join(dataset) + '\n')
example_page_ids = [
'bf23cc96-f6b2-4182-923e-6c163db37eba',
'7c51e648-370e-43b7-9340-3b1a17c13828',
'56521074-59f4-4173-bfc1-4b1384ff8139',
]
def example():
example_page_ids = [
'bf23cc96-f6b2-4182-923e-6c163db37eba',
'7c51e648-370e-43b7-9340-3b1a17c13828',
'56521074-59f4-4173-bfc1-4b1384ff8139',
]
example_volume_ids = [
'8f4005e9-1921-47b0-be7b-e27c7fd29486',
]
kaldi_data_generator = KaldiDataGenerator()
kaldi_partitioner = KaldiPartitionSplitter()
# kaldi_data_generator.run_page(example_page_ids)
kaldi_data_generator.run_volumes(example_volume_ids)
kaldi_partitioner.create_partitions()
def create_parser():
parser = argparse.ArgumentParser(
description="Script to generate Kaldi training data from annotations from Arkindex")
parser.add_argument('-n', '--dataset_name', type=str, required=True,
help='Name of the dataset being created for kaldi')
parser.add_argument('-o', '--out_dir', type=str, required=True,
help='output directory')
parser.add_argument('--train_ratio', type=float, default=0.8,
help='Ratio of pages to be used in train (between 0 and 1)')
parser.add_argument('--test_ratio', type=float, default=0.1,
help='Ratio of pages to be used in train (between 0 and 1 - train_ratio)')
group = parser.add_mutually_exclusive_group(required=False)
group.add_argument('--grayscale', action='store_true',
help='Convert images to grayscale')
group.add_argument('--color', action='store_false',
help='Use color images')
parser.set_defaults(grayscale=True)
parser.add_argument('--volumes', nargs='*',
help='List of volume ids to be used, separated by spaces')
parser.add_argument('--pages', nargs='*',
help='List of page ids to be used, separated by spaces')
return parser
def main():
args = create_parser().parse_args()
print("ARGS", args, '\n')
kaldi_data_generator = KaldiDataGenerator(dataset_name=args.dataset_name,
out_dir_base=args.out_dir,
grayscale=args.grayscale)
kaldi_partitioner = KaldiPartitionSplitter(out_dir_base=args.out_dir,
split_train_ratio=args.train_ratio,
split_test_ratio=args.test_ratio)
# extract all the lines and transcriptions
if args.pages:
kaldi_data_generator.run_page(args.pages)
if args.volumes:
kaldi_data_generator.run_volumes(args.volumes)
print()
# create partitions from all the extracted data
kaldi_partitioner.create_partitions()
example_volume_ids = [
'8f4005e9-1921-47b0-be7b-e27c7fd29486',
]
print("DONE")
kaldi_data_generator = KaldiDataGenerator()
kaldi_partitioner = KaldiPartitionSplitter()
# kaldi_data_generator.run_page(example_page_ids)
kaldi_data_generator.run_volumes(example_volume_ids)
kaldi_partitioner.create_partitions()
if __name__ == '__main__':
main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment