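"""
Generate Kaldi training data from Arkindex annotations: download page images,
extract line images and their transcriptions, then split the extracted lines
into train / test / validation partitions.
"""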
import argparse
import os
import random
from enum import Enum
from io import BytesIO
from pathlib import Path

import cv2
import numpy as np
import requests
from PIL import Image

from apistar.exceptions import ErrorResponse
from arkindex import ArkindexClient, options_from_env

# Arkindex API client configured from environment variables (see options_from_env)
api_client = ArkindexClient(**options_from_env())
def download_image(url):
'''
Download an image and open it with Pillow
'''
assert url.startswith('http'), 'Image URL must be HTTP(S)'
# Download the image
# Cannot use stream=True as urllib's responses do not support the seek(int) method,
# which is explicitly required by Image.open on file-like objects
resp = requests.get(url)
resp.raise_for_status()
# Preprocess the image and prepare it for classification
image = Image.open(BytesIO(resp.content))
    print('Downloaded image {} - size={}x{}'.format(url, image.size[0], image.size[1]))
    return image
def write_file(file_name, content):
with open(file_name, 'w') as f:
f.write(content)
class Extraction(Enum):
boundingRect: int = 0
polygon: int = 1

class KaldiDataGenerator:
    def __init__(self, dataset_name='foo', out_dir_base='/tmp/kaldi_data', grayscale=True,
                 extraction=Extraction.boundingRect):
        self.out_dir_base = out_dir_base
        self.dataset_name = dataset_name
        self.grayscale = grayscale
        self.extraction_mode = extraction
self.out_line_text_dir = os.path.join(self.out_dir_base, 'Transcriptions', self.dataset_name)
os.makedirs(self.out_line_text_dir, exist_ok=True)
self.out_line_img_dir = os.path.join(self.out_dir_base, 'Lines', self.dataset_name)
os.makedirs(self.out_line_img_dir, exist_ok=True)
    def get_image(self, image_url, page_id):
        # Download the full page image, save it to disk and reload it with OpenCV
        out_full_img_dir = os.path.join(self.out_dir_base, 'full', page_id)
os.makedirs(out_full_img_dir, exist_ok=True)
out_full_img_path = os.path.join(out_full_img_dir, 'full.jpg')
if self.grayscale:
download_image(image_url).convert('L').save(
out_full_img_path, format='jpeg')
img = cv2.imread(out_full_img_path, cv2.IMREAD_GRAYSCALE)
else:
download_image(image_url).save(
out_full_img_path, format='jpeg')
img = cv2.imread(out_full_img_path)
return img
    def extract_lines(self, page_id):
        count = 0
line_bounding_rects = []
line_polygons = []
line_transcriptions = []
try:
for res in api_client.paginate('ListTranscriptions', id=page_id, type='line'):
text = res['text']
if not text or not text.strip():
continue
line_transcriptions.append(text)
                polygon = np.asarray(res['zone']['polygon']).clip(0).astype(np.int32)
                line_polygons.append(polygon)
                [x, y, w, h] = cv2.boundingRect(polygon)
                line_bounding_rects.append([x, y, w, h])
count += 1
except ErrorResponse as e:
print("ListTranscriptions failed", e.status_code, e.title, e.content, page_id)
raise e
        if count == 0:
            return  # no transcribed lines on this page
        full_image_url = res['zone']['image']['s3_url']
img = self.get_image(full_image_url, page_id=page_id)
if self.extraction_mode == Extraction.boundingRect:
for i, [x, y, w, h] in enumerate(line_bounding_rects):
cropped = img[y:y + h, x:x + w].copy()
cv2.imwrite(f'{self.out_line_img_dir}/{page_id}_{i}.jpg', cropped)
elif self.extraction_mode == Extraction.polygon:
for i, (polygon, rect) in enumerate(zip(line_polygons, line_bounding_rects)):
polygon_img = self.extract_polygon_image(img, polygon=polygon, rect=rect)
cv2.imwrite(f'{self.out_line_img_dir}/{page_id}_{i}.jpg', polygon_img)
else:
raise ValueError("Unsupported extraction mode")
write_file(f"{self.out_line_text_dir}/{page_id}_{i}.txt", text)
    @staticmethod
    def extract_polygon_image(img: 'np.ndarray', polygon: 'np.ndarray', rect: list) -> 'np.ndarray':
        # Mask everything outside the line polygon and paste the line onto a white background
        pts = polygon.copy()
[x, y, w, h] = rect
cropped = img[y:y + h, x:x + w].copy()
pts = pts - pts.min(axis=0)
mask = np.zeros(cropped.shape[:2], np.uint8)
cv2.drawContours(mask, [pts], -1, (255, 255, 255), -1, cv2.LINE_AA)
dst = cv2.bitwise_and(cropped, cropped, mask=mask)
bg = np.ones_like(cropped, np.uint8) * 255
cv2.bitwise_not(bg, bg, mask=mask)
dst2 = bg + dst
return dst2
    def run_pages(self, page_ids):
        for page_id in page_ids:
print("Page", page_id)
self.extract_lines(page_id)
    def run_volumes(self, volume_ids):
        for volume_id in volume_ids:
print("Vol", volume_id)
page_ids = [page['id'] for page in api_client.paginate('ListElementChildren', id=volume_id)]
self.run_pages(page_ids)
class Split(Enum):
Train: int = 0
Test: int = 1
Validation: int = 2
class KaldiPartitionSplitter:
def __init__(self, out_dir_base='/tmp/kaldi_data', split_train_ratio=0.8, split_test_ratio=0.1):
self.out_dir_base = out_dir_base
self.split_train_ratio = split_train_ratio
self.split_test_ratio = split_test_ratio
self.split_val_ratio = 1 - self.split_train_ratio - self.split_test_ratio
    def page_level_split(self, line_ids):
        # Line ids look like '<page_id>_<line_index>'; group them by page
        page_ids = list({'_'.join(line_id.split('_')[:-1]) for line_id in line_ids})
random.shuffle(page_ids)
page_count = len(page_ids)
train_page_ids = page_ids[:round(page_count * self.split_train_ratio)]
page_ids = page_ids[round(page_count * self.split_train_ratio):]
test_page_ids = page_ids[:round(page_count * self.split_test_ratio)]
page_ids = page_ids[round(page_count * self.split_test_ratio):]
val_page_ids = page_ids
page_dict = {page_id: Split.Train.value for page_id in train_page_ids}
page_dict.update({page_id: Split.Test.value for page_id in test_page_ids})
        page_dict.update({page_id: Split.Validation.value for page_id in val_page_ids})
        return page_dict
    def create_partitions(self):
        lines_path = Path(f'{self.out_dir_base}/Lines')
line_ids = [str(file.relative_to(lines_path).with_suffix('')) for file in lines_path.glob('**/*.jpg')]
page_dict = self.page_level_split(line_ids)
datasets = [[] for _ in range(3)]
for line_id in line_ids:
page_id = '_'.join(line_id.split('_')[:-1])
split_id = page_dict[page_id]
datasets[split_id].append(line_id)
partitions_dir = os.path.join(self.out_dir_base, 'Partitions')
os.makedirs(partitions_dir, exist_ok=True)
for i, dataset in enumerate(datasets):
if not dataset:
print(f"Partition {Split(i).name} is empty! Skipping..")
continue
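            # Write one plain-text list of line ids per split
            # (the '.lst' file naming below is an assumed convention)
            write_file(os.path.join(partitions_dir, f'{Split(i).name}Lines.lst'),
                       '\n'.join(dataset) + '\n')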
description="Script to generate Kaldi training data from annotations from Arkindex",
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('-n', '--dataset_name', type=str, required=True,
                        help='Name of the dataset being created for kaldi '
                             '(used to distinguish between datasets in the Lines and Transcriptions directories)')
parser.add_argument('-o', '--out_dir', type=str, required=True,
help='output directory')
parser.add_argument('--train_ratio', type=float, default=0.8,
help='Ratio of pages to be used in train (between 0 and 1)')
    parser.add_argument('--test_ratio', type=float, default=0.1,
                        help='Ratio of pages to be used in test (between 0 and 1 - train_ratio)')
parser.add_argument('-e', '--extraction_mode', type=lambda x: Extraction[x], default=Extraction.boundingRect,
help=f'Mode for extracting the line images: {[e.name for e in Extraction]}')
group = parser.add_mutually_exclusive_group(required=False)
group.add_argument('--grayscale', action='store_true',
help='Convert images to grayscale')
    group.add_argument('--color', dest='grayscale', action='store_false',
                       help='Use color images')
parser.set_defaults(grayscale=True)
parser.add_argument('--volumes', nargs='*',
help='List of volume ids to be used, separated by spaces')
parser.add_argument('--pages', nargs='*',
help='List of page ids to be used, separated by spaces')
return parser
def main():
args = create_parser().parse_args()
print("ARGS", args, '\n')
kaldi_data_generator = KaldiDataGenerator(dataset_name=args.dataset_name,
out_dir_base=args.out_dir,
grayscale=args.grayscale,
extraction=args.extraction_mode)
kaldi_partitioner = KaldiPartitionSplitter(out_dir_base=args.out_dir,
split_train_ratio=args.train_ratio,
split_test_ratio=args.test_ratio)
# extract all the lines and transcriptions
    if args.pages:
        kaldi_data_generator.run_pages(args.pages)
    if args.volumes:
        kaldi_data_generator.run_volumes(args.volumes)
print()
# create partitions from all the extracted data
kaldi_partitioner.create_partitions()
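

if __name__ == '__main__':
    main()

# Example invocation (script name and ids below are placeholders):
#   python kaldi_data_generator.py --dataset_name foo --out_dir /tmp/kaldi_data \
#       --extraction_mode polygon --volumes <volume_id> [<volume_id> ...]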