From 4a76d32def89002963dfe55f78044c82788d8189 Mon Sep 17 00:00:00 2001
From: Martin <maarand@teklia.com>
Date: Mon, 4 Nov 2019 18:25:18 +0100
Subject: [PATCH] refactor, use class

---
 kaldi_data_generator.py | 258 +++++++++++++++++++++-------------------
 1 file changed, 137 insertions(+), 121 deletions(-)

diff --git a/kaldi_data_generator.py b/kaldi_data_generator.py
index 7e044c3..b7c6126 100644
--- a/kaldi_data_generator.py
+++ b/kaldi_data_generator.py
@@ -1,24 +1,21 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
-import glob
 import os
-from pathlib import Path
-
-import requests
+import random
 from io import BytesIO
-from PIL import Image
+from pathlib import Path
 
 import cv2
-
 import numpy as np
-
-import random
-
+import requests
+from PIL import Image
 from apistar.exceptions import ErrorResponse
 from arkindex import ArkindexClient, options_from_env
 
+
 api_client = ArkindexClient(**options_from_env())
 
+
 def download_image(url):
     '''
     Download an image and open it with Pillow
@@ -33,122 +30,141 @@ def download_image(url):
 
     # Preprocess the image and prepare it for classification
     image = Image.open(BytesIO(resp.content))
     print('Downloaded image {} - size={}x{}'.format(url,
-          image.size[0],
-          image.size[1]))
+                                                    image.size[0],
+                                                    image.size[1]))
     return image
 
+
 def write_file(file_name, content):
     with open(file_name, 'w') as f:
         f.write(content)
 
-def get_image(image_url, grayscale, out_dir):
-    out_full_img_dir = os.path.join(out_dir, 'full', page_id)
-    os.makedirs(out_full_img_dir, exist_ok=True)
-    out_full_img_path = os.path.join(out_full_img_dir, 'full.jpg')
-    if grayscale:
-        download_image(image_url).convert('L').save(
-            out_full_img_path, format='jpeg')
-        img = cv2.imread(out_full_img_path, cv2.IMREAD_GRAYSCALE)
-    else:
-        download_image(image_url).save(
-            out_full_img_path, format='jpeg')
-        img = cv2.imread(out_full_img_path)
-    return img
-
-def extract_lines(page_id, grayscale=True, out_dir='/tmp'):
-    count = 0
-    line_bounding_rects = []
-    line_polygons = []
-    line_transcriptions = []
-    try:
-        for res in api_client.paginate('ListTranscriptions', id=page_id, type='line'):
-            text = res['text']
-            if not text or not text.strip():
-                continue
-            line_transcriptions.append(text)
-            polygon = res['zone']['polygon']
-            line_polygons.append(polygon)
-            [x, y, w, h] = cv2.boundingRect(np.asarray(polygon))
-            line_bounding_rects.append([x, y, w, h])
-            count += 1
-    except ErrorResponse as e:
-        print("ListTranscriptions failed", e.status_code, e.title, e.content, page_id)
-        raise e
-
-    full_image_url = res['zone']['image']['s3_url']
-
-    img = get_image(full_image_url, grayscale=grayscale, out_dir=out_dir)
-
-    out_line_img_dir = os.path.join(out_dir, 'Lines', page_id)
-    os.makedirs(out_line_img_dir, exist_ok=True)
-    for i, [x, y, w, h] in enumerate(line_bounding_rects):
-        croped = img[y:y + h, x:x + w].copy()
-#        cv2.imwrite(f'{out_line_img_dir}/{i}.jpg', croped)
-        cv2.imwrite(f'{out_line_img_dir}_{i}.jpg', croped)
-
-    out_line_text_dir = os.path.join(out_dir, 'Transcriptions', page_id)
-    os.makedirs(out_line_text_dir, exist_ok=True)
-    for i, text in enumerate(line_transcriptions):
-        write_file(f"{out_line_text_dir}_{i}.txt", text)
-#        write_file(f"{out_line_text_dir}/{i}.txt", text)
-
-
-split_train_ratio = 0.8
-split_test_ratio = 0.1
-split_val_ratio = 1 - split_train_ratio - split_test_ratio
-
-
-def page_level_split(line_ids):
-#    page_ids = list({'_'.join(line_id.split('_')[:-1]) for line_id in line_ids})
-    page_ids = list({line_id for line_id in line_ids})
-    random.shuffle(page_ids)
-    page_count = len(page_ids)
-
-    train_page_ids = page_ids[:round(page_count * split_train_ratio)]
-    page_ids = page_ids[round(page_count * split_train_ratio):]
-
-    test_page_ids = page_ids[:round(page_count * split_test_ratio)]
-    page_ids = page_ids[round(page_count * split_test_ratio):]
-
-    val_page_ids = page_ids
-
-    page_dict = {page_id: TRAIN for page_id in train_page_ids}
-    page_dict.update({page_id: TEST for page_id in test_page_ids})
-    page_dict.update({page_id: VAL for page_id in val_page_ids})
-    return (train_page_ids, val_page_ids, test_page_ids), page_dict
-
-TRAIN,TEST,VAL = 0,1,2
-out_file_dict = {0 : 'Train', 1 : 'Test', 2 : 'Validation'}
-
-def create_partitions(line_ids, out_dir):
-    (train_page_ids, val_page_ids, test_page_ids), page_dict = page_level_split(line_ids)
-    datasets = [[] for i in range(3)]
-    for line_id in line_ids:
-        page_id = line_id
-        split_id = page_dict[page_id]
-        datasets[split_id].append(line_id)
-
-    partitions_dir = os.path.join(out_dir, 'Partitions')
-    os.makedirs(partitions_dir, exist_ok=True)
-    for i, dataset in enumerate(datasets):
-        file_name = f"{partitions_dir}/{out_file_dict[i]}Lines.lst"
-        with open(file_name, 'w') as f:
-            f.write('\n'.join(dataset) + '\n')
-
-
-out_dir_base = '/tmp/foo2'
-
-#page_id = 'bf23cc96-f6b2-4182-923e-6c163db37eba'
-page_ids = ['bf23cc96-f6b2-4182-923e-6c163db37eba',
-            '7c51e648-370e-43b7-9340-3b1a17c13828',
-            '56521074-59f4-4173-bfc1-4b1384ff8139',]
-
-for page_id in page_ids:
-    extract_lines(page_id, out_dir=out_dir_base)
-
-
-lines_path = Path(f'{out_dir_base}/Lines')
-line_ids = [str(file.relative_to(lines_path).with_suffix('')) for file in lines_path.glob('**/*.jpg')]
-
-create_partitions(line_ids, out_dir=out_dir_base)
+
+TRAIN, TEST, VAL = 0, 1, 2
+out_file_dict = {0: 'Train', 1: 'Test', 2: 'Validation'}
+
+
+class KaldiDataGenerator:
+
+    def __init__(self, dataset_name='foo', out_dir_base='/tmp/kaldi_data', split_train_ratio=0.8, split_test_ratio=0.1,
+                 grayscale=True):
+        self.out_dir_base = out_dir_base
+        self.dataset_name = dataset_name
+        self.split_train_ratio = split_train_ratio
+        self.split_test_ratio = split_test_ratio
+        self.split_val_ratio = 1 - self.split_train_ratio - self.split_test_ratio
+        self.grayscale = grayscale
+
+    def get_image(self, image_url, page_id):
+        out_full_img_dir = os.path.join(self.out_dir_base, 'full', page_id)
+        os.makedirs(out_full_img_dir, exist_ok=True)
+        out_full_img_path = os.path.join(out_full_img_dir, 'full.jpg')
+        if self.grayscale:
+            download_image(image_url).convert('L').save(
+                out_full_img_path, format='jpeg')
+            img = cv2.imread(out_full_img_path, cv2.IMREAD_GRAYSCALE)
+        else:
+            download_image(image_url).save(
+                out_full_img_path, format='jpeg')
+            img = cv2.imread(out_full_img_path)
+        return img
+
+    def extract_lines(self, page_id):
+        count = 0
+        line_bounding_rects = []
+        line_polygons = []
+        line_transcriptions = []
+        try:
+            for res in api_client.paginate('ListTranscriptions', id=page_id, type='line'):
+                text = res['text']
+                if not text or not text.strip():
+                    continue
+                line_transcriptions.append(text)
+                polygon = res['zone']['polygon']
+                line_polygons.append(polygon)
+                [x, y, w, h] = cv2.boundingRect(np.asarray(polygon))
+                line_bounding_rects.append([x, y, w, h])
+                count += 1
+        except ErrorResponse as e:
+            print("ListTranscriptions failed", e.status_code, e.title, e.content, page_id)
+            raise e
+        print("C", count)
+        full_image_url = res['zone']['image']['s3_url']
+
+        img = self.get_image(full_image_url, page_id=page_id)
+
+        out_line_img_dir = os.path.join(self.out_dir_base, 'Lines', self.dataset_name, page_id)
+        os.makedirs(out_line_img_dir, exist_ok=True)
+        for i, [x, y, w, h] in enumerate(line_bounding_rects):
+            cropped = img[y:y + h, x:x + w].copy()
+            cv2.imwrite(f'{out_line_img_dir}_{i}.jpg', cropped)
+
+        out_line_text_dir = os.path.join(self.out_dir_base, 'Transcriptions', self.dataset_name, page_id)
+        os.makedirs(out_line_text_dir, exist_ok=True)
+        for i, text in enumerate(line_transcriptions):
+            write_file(f"{out_line_text_dir}_{i}.txt", text)
+
+    def page_level_split(self, line_ids):
+        page_ids = list({'_'.join(line_id.split('_')[:-1]) for line_id in line_ids})
+        # page_ids = list({line_id for line_id in line_ids})
+        random.shuffle(page_ids)
+        page_count = len(page_ids)
+
+        train_page_ids = page_ids[:round(page_count * self.split_train_ratio)]
+        page_ids = page_ids[round(page_count * self.split_train_ratio):]
+
+        test_page_ids = page_ids[:round(page_count * self.split_test_ratio)]
+        page_ids = page_ids[round(page_count * self.split_test_ratio):]
+
+        val_page_ids = page_ids
+
+        page_dict = {page_id: TRAIN for page_id in train_page_ids}
+        page_dict.update({page_id: TEST for page_id in test_page_ids})
+        page_dict.update({page_id: VAL for page_id in val_page_ids})
+        return page_dict
+
+    def create_partitions(self):
+        lines_path = Path(f'{self.out_dir_base}/Lines')
+        line_ids = [str(file.relative_to(lines_path).with_suffix('')) for file in lines_path.glob('**/*.jpg')]
+
+        page_dict = self.page_level_split(line_ids)
+        datasets = [[] for _ in range(3)]
+        for line_id in line_ids:
+            page_id = '_'.join(line_id.split('_')[:-1])
+            split_id = page_dict[page_id]
+            datasets[split_id].append(line_id)
+
+        partitions_dir = os.path.join(self.out_dir_base, 'Partitions')
+        os.makedirs(partitions_dir, exist_ok=True)
+        for i, dataset in enumerate(datasets):
+            file_name = f"{partitions_dir}/{out_file_dict[i]}Lines.lst"
+            write_file(file_name, '\n'.join(dataset) + '\n')
+
+    def run_pages(self, page_ids):
+        for page_id in page_ids:
+            print("P", page_id)
+            self.extract_lines(page_id)
+
+    def run_volumes(self, volume_ids):
+        for volume_id in volume_ids:
+            print("V", volume_id)
+            page_ids = [page['id'] for page in api_client.paginate('ListElementChildren', id=volume_id)]
+            self.run_pages(page_ids)
+
+
+example_page_ids = [
+    'bf23cc96-f6b2-4182-923e-6c163db37eba',
+    '7c51e648-370e-43b7-9340-3b1a17c13828',
+    '56521074-59f4-4173-bfc1-4b1384ff8139',
+]
+
+example_volume_ids = [
+    '8f4005e9-1921-47b0-be7b-e27c7fd29486',
+]
+
+kaldi_data_generator = KaldiDataGenerator()
+
+# kaldi_data_generator.run_pages(example_page_ids)
+kaldi_data_generator.run_volumes(example_volume_ids)
+kaldi_data_generator.create_partitions()
-- 
GitLab