Skip to content
Snippets Groups Projects
Commit 967fb987 authored by Martin's avatar Martin
Browse files

extract kaldi partition splitter

parent 4a76d32d
No related branches found
No related tags found
No related merge requests found
......@@ -56,6 +56,11 @@ class KaldiDataGenerator:
self.split_val_ratio = 1 - self.split_train_ratio - self.split_test_ratio
self.grayscale = grayscale
self.out_line_text_dir = os.path.join(self.out_dir_base, 'Transcriptions', self.dataset_name)
os.makedirs(self.out_line_text_dir, exist_ok=True)
self.out_line_img_dir = os.path.join(self.out_dir_base, 'Lines', self.dataset_name)
os.makedirs(self.out_line_img_dir, exist_ok=True)
def get_image(self, image_url, page_id):
out_full_img_dir = os.path.join(self.out_dir_base, 'full', page_id)
os.makedirs(out_full_img_dir, exist_ok=True)
......@@ -89,25 +94,39 @@ class KaldiDataGenerator:
except ErrorResponse as e:
print("ListTranscriptions failed", e.status_code, e.title, e.content, page_id)
raise e
print("C", count)
print("Num of lines", count)
full_image_url = res['zone']['image']['s3_url']
img = self.get_image(full_image_url, page_id=page_id)
out_line_img_dir = os.path.join(self.out_dir_base, 'Lines', self.dataset_name, page_id)
os.makedirs(out_line_img_dir, exist_ok=True)
for i, [x, y, w, h] in enumerate(line_bounding_rects):
cropped = img[y:y + h, x:x + w].copy()
cv2.imwrite(f'{out_line_img_dir}_{i}.jpg', cropped)
cv2.imwrite(f'{self.out_line_img_dir}_{i}.jpg', cropped)
out_line_text_dir = os.path.join(self.out_dir_base, 'Transcriptions', self.dataset_name, page_id)
os.makedirs(out_line_text_dir, exist_ok=True)
for i, text in enumerate(line_transcriptions):
write_file(f"{out_line_text_dir}_{i}.txt", text)
write_file(f"{self.out_line_text_dir}_{i}.txt", text)
def run_pages(self, page_ids):
    """Process every page id in ``page_ids``, one after the other.

    For each id a progress line ("Page <id>") is printed, then the id is
    handed to ``self.extract_lines``.

    Args:
        page_ids: iterable of page identifiers.
    """
    for current_page in page_ids:
        # Progress trace first, so a failing page is identifiable in the log.
        print("Page", current_page)
        self.extract_lines(current_page)
def run_volumes(self, volume_ids):
    """Process all children of each volume in ``volume_ids``.

    For every volume a progress line ("Vol <id>") is printed, the volume's
    children are listed through ``api_client.paginate('ListElementChildren',
    id=...)``, and the ``'id'`` of each child is passed on to
    ``self.run_pages``.

    Args:
        volume_ids: iterable of volume identifiers.
    """
    for vol in volume_ids:
        print("Vol", vol)
        children = api_client.paginate('ListElementChildren', id=vol)
        # Collect every child id before processing starts.
        # NOTE(review): children are presumably page elements -- confirm.
        self.run_pages([child['id'] for child in children])
class KaldiPartitionSplitter:
def __init__(self, out_dir_base='/tmp/kaldi_data', split_train_ratio=0.8, split_test_ratio=0.1):
self.out_dir_base = out_dir_base
self.split_train_ratio = split_train_ratio
self.split_test_ratio = split_test_ratio
def page_level_split(self, line_ids):
page_ids = list({'_'.join(line_id.split('_')[:-1]) for line_id in line_ids})
# page_ids = list({line_id for line_id in line_ids})
random.shuffle(page_ids)
page_count = len(page_ids)
......@@ -141,16 +160,6 @@ class KaldiDataGenerator:
file_name = f"{partitions_dir}/{out_file_dict[i]}Lines.lst"
write_file(file_name, '\n'.join(dataset) + '\n')
def run_pages(self, page_ids):
    """Run line extraction for each page id, in the order given.

    Prints "P <id>" before calling ``self.extract_lines`` on that id.

    Args:
        page_ids: iterable of page identifiers.
    """
    for pid in page_ids:
        print("P", pid)
        self.extract_lines(pid)
def run_volumes(self, volume_ids):
    """Run ``self.run_pages`` on the children of each listed volume.

    Prints "V <id>" for every volume, gathers the ``'id'`` of each element
    returned by ``api_client.paginate('ListElementChildren', id=...)``, and
    forwards the gathered ids to ``self.run_pages``.

    Args:
        volume_ids: iterable of volume identifiers.
    """
    for volume in volume_ids:
        print("V", volume)
        collected = []
        for element in api_client.paginate('ListElementChildren', id=volume):
            collected.append(element['id'])
        # The full id list is built first, then processed in one call.
        self.run_pages(collected)
example_page_ids = [
......@@ -164,7 +173,8 @@ example_volume_ids = [
]
kaldi_data_generator = KaldiDataGenerator()
kaldi_partitioner = KaldiPartitionSplitter()
# kaldi_data_generator.run_pages(example_page_ids)
kaldi_data_generator.run_volumes(example_volume_ids)
kaldi_data_generator.create_partitions()
kaldi_partitioner.create_partitions()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment