import os
import random
from enum import Enum
from io import BytesIO
from pathlib import Path

import cv2
import numpy as np
import requests
from apistar.exceptions import ErrorResponse
from arkindex import ArkindexClient, options_from_env
from PIL import Image

# Instantiate the Arkindex API client from the ARKINDEX_API_URL and
# ARKINDEX_API_TOKEN environment variables
api_client = ArkindexClient(**options_from_env())

def download_image(url):
'''
Download an image and open it with Pillow
'''
assert url.startswith('http'), 'Image URL must be HTTP(S)'
# Download the image
    # Cannot use stream=True, as urllib3's raw responses do not support the seek(int) method,
# which is explicitly required by Image.open on file-like objects
resp = requests.get(url)
resp.raise_for_status()
# Preprocess the image and prepare it for classification
image = Image.open(BytesIO(resp.content))
    print('Downloaded image {} - size={}x{}'.format(url, image.size[0], image.size[1]))
    return image
def write_file(file_name, content):
with open(file_name, 'w') as f:
f.write(content)
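
# Fetches pages from Arkindex, crops every transcribed text line out of the
# page image, and writes matching image/text pairs in a Kaldi-style layout:
#   <out_dir_base>/Lines/<dataset_name>/<page_id>_<i>.jpg
#   <out_dir_base>/Transcriptions/<dataset_name>/<page_id>_<i>.txt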
class KaldiDataGenerator:
def __init__(self, dataset_name='foo', out_dir_base='/tmp/kaldi_data', split_train_ratio=0.8, split_test_ratio=0.1,
grayscale=True):
self.out_dir_base = out_dir_base
self.dataset_name = dataset_name
self.split_train_ratio = split_train_ratio
self.split_test_ratio = split_test_ratio
self.split_val_ratio = 1 - self.split_train_ratio - self.split_test_ratio
self.grayscale = grayscale
self.out_line_text_dir = os.path.join(self.out_dir_base, 'Transcriptions', self.dataset_name)
os.makedirs(self.out_line_text_dir, exist_ok=True)
self.out_line_img_dir = os.path.join(self.out_dir_base, 'Lines', self.dataset_name)
os.makedirs(self.out_line_img_dir, exist_ok=True)
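
    # Download the full page image, save it under full/<page_id>/full.jpg and
    # reload it with OpenCV (as grayscale if requested) so lines can be cropped from it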
def get_image(self, image_url, page_id):
out_full_img_dir = os.path.join(self.out_dir_base, 'full', page_id)
os.makedirs(out_full_img_dir, exist_ok=True)
out_full_img_path = os.path.join(out_full_img_dir, 'full.jpg')
if self.grayscale:
download_image(image_url).convert('L').save(
out_full_img_path, format='jpeg')
img = cv2.imread(out_full_img_path, cv2.IMREAD_GRAYSCALE)
else:
download_image(image_url).save(
out_full_img_path, format='jpeg')
img = cv2.imread(out_full_img_path)
return img
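
    # Fetch every line transcription of a page, compute the bounding box of each
    # line polygon, then write one image crop and one text file per line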
def extract_lines(self, page_id):
line_bounding_rects = []
line_polygons = []
line_transcriptions = []
try:
for res in api_client.paginate('ListTranscriptions', id=page_id, type='line'):
text = res['text']
if not text or not text.strip():
continue
line_transcriptions.append(text)
polygon = res['zone']['polygon']
line_polygons.append(polygon)
                [x, y, w, h] = cv2.boundingRect(np.asarray(polygon, dtype=np.int32))
line_bounding_rects.append([x, y, w, h])
except ErrorResponse as e:
print("ListTranscriptions failed", e.status_code, e.title, e.content, page_id)
            raise
        if not line_bounding_rects:
            # Nothing to extract on this page
            return
        full_image_url = res['zone']['image']['s3_url']
        img = self.get_image(full_image_url, page_id=page_id)
        for i, [x, y, w, h] in enumerate(line_bounding_rects):
            cropped = img[y:y + h, x:x + w].copy()
            # Name each crop '<page_id>_<i>' so the splitter can recover the page id
            cv2.imwrite(os.path.join(self.out_line_img_dir, f'{page_id}_{i}.jpg'), cropped)
            write_file(os.path.join(self.out_line_text_dir, f'{page_id}_{i}.txt'), line_transcriptions[i])
def run_pages(self, page_ids):
for page_id in page_ids:
print("Page", page_id)
self.extract_lines(page_id)
def run_volumes(self, volume_ids):
for volume_id in volume_ids:
print("Vol", volume_id)
page_ids = [page['id'] for page in api_client.paginate('ListElementChildren', id=volume_id)]
self.run_pages(page_ids)
class Split(Enum):
    Train = 0
    Test = 1
    Validation = 2
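
# Partitions the extracted lines into train/test/validation sets. The split is
# done at the page level so that all lines from a page land in the same set;
# the validation ratio is whatever remains after train and test.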
class KaldiPartitionSplitter:
def __init__(self, out_dir_base='/tmp/kaldi_data', split_train_ratio=0.8, split_test_ratio=0.1):
self.out_dir_base = out_dir_base
self.split_train_ratio = split_train_ratio
self.split_test_ratio = split_test_ratio
def page_level_split(self, line_ids):
page_ids = list({'_'.join(line_id.split('_')[:-1]) for line_id in line_ids})
random.shuffle(page_ids)
page_count = len(page_ids)
train_page_ids = page_ids[:round(page_count * self.split_train_ratio)]
page_ids = page_ids[round(page_count * self.split_train_ratio):]
test_page_ids = page_ids[:round(page_count * self.split_test_ratio)]
page_ids = page_ids[round(page_count * self.split_test_ratio):]
val_page_ids = page_ids
page_dict = {page_id: Split.Train.value for page_id in train_page_ids}
page_dict.update({page_id: Split.Test.value for page_id in test_page_ids})
page_dict.update({page_id: Split.Validation.value for page_id in val_page_ids})
        return page_dict

    def create_partitions(self):
        lines_path = Path(f'{self.out_dir_base}/Lines')
        line_ids = [str(file.relative_to(lines_path).with_suffix('')) for file in lines_path.glob('**/*.jpg')]
        page_dict = self.page_level_split(line_ids)
datasets = [[] for _ in range(3)]
for line_id in line_ids:
page_id = '_'.join(line_id.split('_')[:-1])
split_id = page_dict[page_id]
datasets[split_id].append(line_id)
partitions_dir = os.path.join(self.out_dir_base, 'Partitions')
os.makedirs(partitions_dir, exist_ok=True)
        for i, dataset in enumerate(datasets):
            # One list file per split; the '<SplitName>Lines.lst' naming is an assumption
            file_name = os.path.join(partitions_dir, f'{Split(i).name}Lines.lst')
            write_file(file_name, '\n'.join(dataset) + '\n')
example_page_ids = [
'bf23cc96-f6b2-4182-923e-6c163db37eba',
'7c51e648-370e-43b7-9340-3b1a17c13828',
'56521074-59f4-4173-bfc1-4b1384ff8139',
]
example_volume_ids = [
'8f4005e9-1921-47b0-be7b-e27c7fd29486',
]
kaldi_data_generator = KaldiDataGenerator()
# kaldi_data_generator.run_pages(example_page_ids)
kaldi_data_generator.run_volumes(example_volume_ids)
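
# A minimal sketch of the partitioning step, assuming the generator above has
# already written its line crops under /tmp/kaldi_data/Lines
kaldi_partition_splitter = KaldiPartitionSplitter()
kaldi_partition_splitter.create_partitions()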