From 788aa6354c06538dd289bc908fad9a2e5e977673 Mon Sep 17 00:00:00 2001
From: Martin <maarand@teklia.com>
Date: Tue, 5 Nov 2019 14:38:39 +0100
Subject: [PATCH] add argparse

---
 kaldi_data_generator.py | 95 ++++++++++++++++++++++++++++++++---------
 1 file changed, 76 insertions(+), 19 deletions(-)

diff --git a/kaldi_data_generator.py b/kaldi_data_generator.py
index c02daf1..09bbf8a 100644
--- a/kaldi_data_generator.py
+++ b/kaldi_data_generator.py
@@ -1,6 +1,6 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
-
+import argparse
 import os
 import random
 from enum import Enum
@@ -44,13 +44,9 @@ def write_file(file_name, content):
 
 class KaldiDataGenerator:
 
-    def __init__(self, dataset_name='foo', out_dir_base='/tmp/kaldi_data', split_train_ratio=0.8, split_test_ratio=0.1,
-                 grayscale=True):
+    def __init__(self, dataset_name='foo', out_dir_base='/tmp/kaldi_data', grayscale=True):
         self.out_dir_base = out_dir_base
         self.dataset_name = dataset_name
-        self.split_train_ratio = split_train_ratio
-        self.split_test_ratio = split_test_ratio
-        self.split_val_ratio = 1 - self.split_train_ratio - self.split_test_ratio
         self.grayscale = grayscale
 
         self.out_line_text_dir = os.path.join(self.out_dir_base, 'Transcriptions', self.dataset_name)
@@ -127,6 +123,7 @@ class KaldiPartitionSplitter:
         self.out_dir_base = out_dir_base
         self.split_train_ratio = split_train_ratio
         self.split_test_ratio = split_test_ratio
+        self.split_val_ratio = 1 - self.split_train_ratio - self.split_test_ratio
 
     def page_level_split(self, line_ids):
         page_ids = list({'_'.join(line_id.split('_')[:-1]) for line_id in line_ids})
@@ -161,23 +158,83 @@ class KaldiPartitionSplitter:
         partitions_dir = os.path.join(self.out_dir_base, 'Partitions')
         os.makedirs(partitions_dir, exist_ok=True)
         for i, dataset in enumerate(datasets):
+            if not dataset:
+                print(f"Partition {Split(i).name} is empty! Skipping..")
+                continue
             file_name = f"{partitions_dir}/{Split(i).name}Lines.lst"
             write_file(file_name, '\n'.join(dataset) + '\n')
 
 
-example_page_ids = [
-    'bf23cc96-f6b2-4182-923e-6c163db37eba',
-    '7c51e648-370e-43b7-9340-3b1a17c13828',
-    '56521074-59f4-4173-bfc1-4b1384ff8139',
-]
+def example():
+    example_page_ids = [
+        'bf23cc96-f6b2-4182-923e-6c163db37eba',
+        '7c51e648-370e-43b7-9340-3b1a17c13828',
+        '56521074-59f4-4173-bfc1-4b1384ff8139',
+    ]
+
+    example_volume_ids = [
+        '8f4005e9-1921-47b0-be7b-e27c7fd29486',
+    ]
+
+    kaldi_data_generator = KaldiDataGenerator()
+    kaldi_partitioner = KaldiPartitionSplitter()
+
+    # kaldi_data_generator.run_page(example_page_ids)
+    kaldi_data_generator.run_volumes(example_volume_ids)
+    kaldi_partitioner.create_partitions()
+
+
+def create_parser():
+    parser = argparse.ArgumentParser(
+        description="Script to generate Kaldi training data from annotations from Arkindex")
+    parser.add_argument('-n', '--dataset_name', type=str, required=True,
+                        help='Name of the dataset being created for kaldi')
+    parser.add_argument('-o', '--out_dir', type=str, required=True,
+                        help='output directory')
+    parser.add_argument('--train_ratio', type=float, default=0.8,
+                        help='Ratio of pages to be used in train (between 0 and 1)')
+    parser.add_argument('--test_ratio', type=float, default=0.1,
+                        help='Ratio of pages to be used in train (between 0 and 1 - train_ratio)')
+
+    group = parser.add_mutually_exclusive_group(required=False)
+    group.add_argument('--grayscale', action='store_true',
+                       help='Convert images to grayscale')
+    group.add_argument('--color', action='store_false',
+                       help='Use color images')
+    parser.set_defaults(grayscale=True)
+
+    parser.add_argument('--volumes', nargs='*',
+                        help='List of volume ids to be used, separated by spaces')
+    parser.add_argument('--pages', nargs='*',
+                        help='List of page ids to be used, separated by spaces')
+
+    return parser
+
+
+def main():
+    args = create_parser().parse_args()
+
+    print("ARGS", args, '\n')
+
+    kaldi_data_generator = KaldiDataGenerator(dataset_name=args.dataset_name,
+                                              out_dir_base=args.out_dir,
+                                              grayscale=args.grayscale)
+
+    kaldi_partitioner = KaldiPartitionSplitter(out_dir_base=args.out_dir,
+                                               split_train_ratio=args.train_ratio,
+                                               split_test_ratio=args.test_ratio)
+    # extract all the lines and transcriptions
+    if args.pages:
+        kaldi_data_generator.run_page(args.pages)
+    if args.volumes:
+        kaldi_data_generator.run_volumes(args.volumes)
+
+    print()
+    # create partitions from all the extracted data
+    kaldi_partitioner.create_partitions()
 
-example_volume_ids = [
-    '8f4005e9-1921-47b0-be7b-e27c7fd29486',
-]
+    print("DONE")
 
-kaldi_data_generator = KaldiDataGenerator()
-kaldi_partitioner = KaldiPartitionSplitter()
 
-# kaldi_data_generator.run_page(example_page_ids)
-kaldi_data_generator.run_volumes(example_volume_ids)
-kaldi_partitioner.create_partitions()
+if __name__ == '__main__':
+    main()
-- 
GitLab