From 967fb98705c13e3b95941a09ed2e35fbe3c54446 Mon Sep 17 00:00:00 2001
From: Martin <maarand@teklia.com>
Date: Tue, 5 Nov 2019 13:24:43 +0100
Subject: [PATCH] extract kaldi partition splitter

---
 kaldi_data_generator.py | 48 +++++++++++++++++++++++++----------------
 1 file changed, 29 insertions(+), 19 deletions(-)

diff --git a/kaldi_data_generator.py b/kaldi_data_generator.py
index b7c6126..5825f9f 100644
--- a/kaldi_data_generator.py
+++ b/kaldi_data_generator.py
@@ -56,6 +56,11 @@ class KaldiDataGenerator:
         self.split_val_ratio = 1 - self.split_train_ratio - self.split_test_ratio
         self.grayscale = grayscale
 
+        self.out_line_text_dir = os.path.join(self.out_dir_base, 'Transcriptions', self.dataset_name)
+        os.makedirs(self.out_line_text_dir, exist_ok=True)
+        self.out_line_img_dir = os.path.join(self.out_dir_base, 'Lines', self.dataset_name)
+        os.makedirs(self.out_line_img_dir, exist_ok=True)
+
     def get_image(self, image_url, page_id):
         out_full_img_dir = os.path.join(self.out_dir_base, 'full', page_id)
         os.makedirs(out_full_img_dir, exist_ok=True)
@@ -89,25 +94,39 @@ class KaldiDataGenerator:
         except ErrorResponse as e:
             print("ListTranscriptions failed", e.status_code, e.title, e.content, page_id)
             raise e
-        print("C", count)
+        print("Num of lines", count)
         full_image_url = res['zone']['image']['s3_url']
 
         img = self.get_image(full_image_url, page_id=page_id)
 
-        out_line_img_dir = os.path.join(self.out_dir_base, 'Lines', self.dataset_name, page_id)
-        os.makedirs(out_line_img_dir, exist_ok=True)
         for i, [x, y, w, h] in enumerate(line_bounding_rects):
             cropped = img[y:y + h, x:x + w].copy()
-            cv2.imwrite(f'{out_line_img_dir}_{i}.jpg', cropped)
+            cv2.imwrite(os.path.join(self.out_line_img_dir, f'{page_id}_{i}.jpg'), cropped)  # dir is shared by all pages: page_id keeps names unique
 
-        out_line_text_dir = os.path.join(self.out_dir_base, 'Transcriptions', self.dataset_name, page_id)
-        os.makedirs(out_line_text_dir, exist_ok=True)
         for i, text in enumerate(line_transcriptions):
-            write_file(f"{out_line_text_dir}_{i}.txt", text)
+            write_file(os.path.join(self.out_line_text_dir, f'{page_id}_{i}.txt'), text)  # same <page_id>_<i> stem as the line image
+
+    def run_pages(self, page_ids):
+        for page_id in page_ids:
+            print("Page", page_id)
+            self.extract_lines(page_id)
+
+    def run_volumes(self, volume_ids):
+        for volume_id in volume_ids:
+            print("Vol", volume_id)
+            page_ids = [page['id'] for page in api_client.paginate('ListElementChildren', id=volume_id)]
+            self.run_pages(page_ids)
+
+
+class KaldiPartitionSplitter:
+
+    def __init__(self, out_dir_base='/tmp/kaldi_data', split_train_ratio=0.8, split_test_ratio=0.1):
+        self.out_dir_base = out_dir_base
+        self.split_train_ratio = split_train_ratio
+        self.split_test_ratio = split_test_ratio
 
     def page_level_split(self, line_ids):
         page_ids = list({'_'.join(line_id.split('_')[:-1]) for line_id in line_ids})
-        # page_ids = list({line_id for line_id in line_ids})
         random.shuffle(page_ids)
         page_count = len(page_ids)
 
@@ -141,16 +160,6 @@ class KaldiDataGenerator:
             file_name = f"{partitions_dir}/{out_file_dict[i]}Lines.lst"
             write_file(file_name, '\n'.join(dataset) + '\n')
 
-    def run_pages(self, page_ids):
-        for page_id in page_ids:
-            print("P", page_id)
-            self.extract_lines(page_id)
-
-    def run_volumes(self, volume_ids):
-        for volume_id in volume_ids:
-            print("V", volume_id)
-            page_ids = [page['id'] for page in api_client.paginate('ListElementChildren', id=volume_id)]
-            self.run_pages(page_ids)
 
 
 example_page_ids = [
@@ -164,7 +173,8 @@ example_volume_ids = [
 ]
 
 kaldi_data_generator = KaldiDataGenerator()
+kaldi_partitioner = KaldiPartitionSplitter()
 
 # kaldi_data_generator.run_page(example_page_ids)
 kaldi_data_generator.run_volumes(example_volume_ids)
-kaldi_data_generator.create_partitions()
+kaldi_partitioner.create_partitions()
-- 
GitLab