From 967fb98705c13e3b95941a09ed2e35fbe3c54446 Mon Sep 17 00:00:00 2001 From: Martin <maarand@teklia.com> Date: Tue, 5 Nov 2019 13:24:43 +0100 Subject: [PATCH] extract kaldi partition splitter --- kaldi_data_generator.py | 48 +++++++++++++++++++++++++---------------- 1 file changed, 29 insertions(+), 19 deletions(-) diff --git a/kaldi_data_generator.py b/kaldi_data_generator.py index b7c6126..5825f9f 100644 --- a/kaldi_data_generator.py +++ b/kaldi_data_generator.py @@ -56,6 +56,11 @@ class KaldiDataGenerator: self.split_val_ratio = 1 - self.split_train_ratio - self.split_test_ratio self.grayscale = grayscale + self.out_line_text_dir = os.path.join(self.out_dir_base, 'Transcriptions', self.dataset_name) + os.makedirs(self.out_line_text_dir, exist_ok=True) + self.out_line_img_dir = os.path.join(self.out_dir_base, 'Lines', self.dataset_name) + os.makedirs(self.out_line_img_dir, exist_ok=True) + def get_image(self, image_url, page_id): out_full_img_dir = os.path.join(self.out_dir_base, 'full', page_id) os.makedirs(out_full_img_dir, exist_ok=True) @@ -89,25 +94,39 @@ class KaldiDataGenerator: except ErrorResponse as e: print("ListTranscriptions failed", e.status_code, e.title, e.content, page_id) raise e - print("C", count) + print("Num of lines", count) full_image_url = res['zone']['image']['s3_url'] img = self.get_image(full_image_url, page_id=page_id) - out_line_img_dir = os.path.join(self.out_dir_base, 'Lines', self.dataset_name, page_id) - os.makedirs(out_line_img_dir, exist_ok=True) for i, [x, y, w, h] in enumerate(line_bounding_rects): cropped = img[y:y + h, x:x + w].copy() - cv2.imwrite(f'{out_line_img_dir}_{i}.jpg', cropped) + cv2.imwrite(f'{self.out_line_img_dir}_{i}.jpg', cropped) - out_line_text_dir = os.path.join(self.out_dir_base, 'Transcriptions', self.dataset_name, page_id) - os.makedirs(out_line_text_dir, exist_ok=True) for i, text in enumerate(line_transcriptions): - write_file(f"{out_line_text_dir}_{i}.txt", text) + write_file(f"{self.out_line_text_dir}_{i}.txt", text) + + def run_pages(self, page_ids): + for page_id in page_ids: + print("Page", page_id) + self.extract_lines(page_id) + + def run_volumes(self, volume_ids): + for volume_id in volume_ids: + print("Vol", volume_id) + page_ids = [page['id'] for page in api_client.paginate('ListElementChildren', id=volume_id)] + self.run_pages(page_ids) + + +class KaldiPartitionSplitter: + + def __init__(self, out_dir_base='/tmp/kaldi_data', split_train_ratio=0.8, split_test_ratio=0.1): + self.out_dir_base = out_dir_base + self.split_train_ratio = split_train_ratio + self.split_test_ratio = split_test_ratio def page_level_split(self, line_ids): page_ids = list({'_'.join(line_id.split('_')[:-1]) for line_id in line_ids}) - # page_ids = list({line_id for line_id in line_ids}) random.shuffle(page_ids) page_count = len(page_ids) @@ -141,16 +160,6 @@ class KaldiDataGenerator: file_name = f"{partitions_dir}/{out_file_dict[i]}Lines.lst" write_file(file_name, '\n'.join(dataset) + '\n') - def run_pages(self, page_ids): - for page_id in page_ids: - print("P", page_id) - self.extract_lines(page_id) - - def run_volumes(self, volume_ids): - for volume_id in volume_ids: - print("V", volume_id) - page_ids = [page['id'] for page in api_client.paginate('ListElementChildren', id=volume_id)] - self.run_pages(page_ids) example_page_ids = [ @@ -164,7 +173,8 @@ example_volume_ids = [ ] kaldi_data_generator = KaldiDataGenerator() +kaldi_partitioner = KaldiPartitionSplitter() # kaldi_data_generator.run_page(example_page_ids) kaldi_data_generator.run_volumes(example_volume_ids) -kaldi_data_generator.create_partitions() +kaldi_partitioner.create_partitions() -- GitLab