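"""
Generate Kaldi training data (line images, transcriptions and partition lists)
from annotations stored in Arkindex.
"""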
import argparse
import logging
import os
import random
from collections import namedtuple
from enum import Enum
from io import BytesIO
from pathlib import Path

import cv2
import numpy as np
import requests
import tqdm
from apistar.exceptions import ErrorResponse
from arkindex import ArkindexClient, options_from_env
from PIL import Image

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s/%(name)s: %(message)s"
)
logger = logging.getLogger(os.path.basename(__file__))

# Arkindex API client, configured from the environment
api_client = ArkindexClient(**options_from_env())

# (x, y, width, height) rectangle type, used only to annotate the bounding boxes
# computed with cv2.boundingRect
Box = namedtuple('Box', ['x', 'y', 'width', 'height'])
def download_image(url):
    '''
    Download an image and open it with Pillow
    '''
    assert url.startswith('http'), 'Image URL must be HTTP(S)'
    # Download the image
    # Cannot use stream=True as urllib's responses do not support the seek(int) method,
    # which is explicitly required by Image.open on file-like objects
    resp = requests.get(url)
    resp.raise_for_status()
    # Open the downloaded bytes with Pillow
    image = Image.open(BytesIO(resp.content))
    logger.debug('Downloaded image {} - size={}x{}'.format(url, image.size[0], image.size[1]))
    return image


def write_file(file_name, content):
    with open(file_name, 'w') as f:
        f.write(content)
class Extraction(Enum):
    boundingRect: int = 0
    polygon: int = 1


class KaldiDataGenerator:
    def __init__(self, dataset_name='foo', out_dir_base='/tmp/kaldi_data', grayscale=True,
                 extraction=Extraction.boundingRect, accepted_slugs=None, accepted_classes=None,
                 filter_printed=False):
        self.out_dir_base = out_dir_base
        self.dataset_name = dataset_name
        self.grayscale = grayscale
        self.extraction_mode = extraction
        self.accepted_slugs = accepted_slugs
        self.should_filter_by_slug = bool(self.accepted_slugs)
        self.accepted_classes = accepted_classes
        self.should_filter_by_class = bool(self.accepted_classes)
        self.should_filter_printed = filter_printed
        self.out_line_text_dir = os.path.join(self.out_dir_base, 'Transcriptions', self.dataset_name)
        os.makedirs(self.out_line_text_dir, exist_ok=True)
        self.out_line_img_dir = os.path.join(self.out_dir_base, 'Lines', self.dataset_name)
        os.makedirs(self.out_line_img_dir, exist_ok=True)
    def get_image(self, image_url, page_id):
        # Download the full page image, cache it on disk and reload it with OpenCV
        out_full_img_dir = os.path.join(self.out_dir_base, 'full', page_id)
        os.makedirs(out_full_img_dir, exist_ok=True)
        out_full_img_path = os.path.join(out_full_img_dir, 'full.jpg')
        if self.grayscale:
            download_image(image_url).convert('L').save(
                out_full_img_path, format='jpeg')
            img = cv2.imread(out_full_img_path, cv2.IMREAD_GRAYSCALE)
        else:
            download_image(image_url).save(
                out_full_img_path, format='jpeg')
            img = cv2.imread(out_full_img_path)
        return img
    def extract_lines(self, page_id: str):
        # Optionally restrict extraction to zones whose elements carry an accepted class
        accepted_zones = []
        if self.should_filter_by_class:
            for elt in api_client.paginate('ListElementChildren', id=page_id, with_best_classes=True):
                printed = True
                for classification in elt['best_classes']:
                    if classification['ml_class']['name'] == 'handwritten':
                        printed = False
                for classification in elt['best_classes']:
                    if classification['ml_class']['name'] in self.accepted_classes:
                        if self.should_filter_printed:
                            if not printed:
                                accepted_zones.append(elt['zone']['id'])
                        else:
                            accepted_zones.append(elt['zone']['id'])
            logger.info('Number of accepted zones for page {}: {}'.format(page_id, len(accepted_zones)))

        # Collect the line transcriptions with their polygons and bounding boxes
        lines = []
        try:
            for res in api_client.paginate('ListTranscriptions', id=page_id, type='line'):
                if self.should_filter_by_slug and res['source']['slug'] not in self.accepted_slugs:
                    continue
                if self.should_filter_by_class and res['zone']['id'] not in accepted_zones:
                    continue
                text = res['text']
                if not text or not text.strip():
                    continue
                polygon = np.asarray(res['zone']['polygon']).clip(0)
                [x, y, w, h] = cv2.boundingRect(polygon)
                lines.append(((x, y, w, h), polygon, text))
        except ErrorResponse as e:
            logger.info(f"ListTranscriptions failed {e.status_code} - {e.title} - {e.content} - {page_id}")
            # Skip pages whose transcriptions cannot be listed
            return

        if not lines:
            return

        full_image_url = res['zone']['image']['url'] + '/full/full/0/default.jpg'
        img = self.get_image(full_image_url, page_id=page_id)

        # sort vertically then horizontally
        sorted_lines = sorted(lines, key=lambda key: (key[0][1], key[0][0]))

        if self.extraction_mode == Extraction.boundingRect:
            for i, ((x, y, w, h), polygon, text) in enumerate(sorted_lines):
                cropped = img[y:y + h, x:x + w].copy()
                cv2.imwrite(f'{self.out_line_img_dir}/{page_id}_{i}.jpg', cropped)
        elif self.extraction_mode == Extraction.polygon:
            for i, (rect, polygon, text) in enumerate(sorted_lines):
                polygon_img = self.extract_polygon_image(img, polygon=polygon, rect=rect)
                cv2.imwrite(f'{self.out_line_img_dir}/{page_id}_{i}.jpg', polygon_img)
        else:
            raise ValueError("Unsupported extraction mode")

        for i, (rect, polygon, text) in enumerate(sorted_lines):
            write_file(f"{self.out_line_text_dir}/{page_id}_{i}.txt", text)
    @staticmethod
    def extract_polygon_image(img: 'np.ndarray', polygon: 'np.ndarray', rect: Box) -> 'np.ndarray':
        # Crop the bounding rectangle, then mask out everything outside the polygon
        pts = polygon.copy()
        [x, y, w, h] = rect
        cropped = img[y:y + h, x:x + w].copy()
        # Shift the polygon to the cropped image's coordinate system;
        # drawContours expects int32 point coordinates
        pts = (pts - pts.min(axis=0)).astype(np.int32)
        mask = np.zeros(cropped.shape[:2], np.uint8)
        cv2.drawContours(mask, [pts], -1, (255, 255, 255), -1, cv2.LINE_AA)
        dst = cv2.bitwise_and(cropped, cropped, mask=mask)
        # Fill the area outside the polygon with white
        bg = np.ones_like(cropped, np.uint8) * 255
        cv2.bitwise_not(bg, bg, mask=mask)
        dst2 = bg + dst
        return dst2
    def run_pages(self, page_ids: list):
        for page_id in tqdm.tqdm(page_ids):
            logger.debug(f"Page {page_id}")
            self.extract_lines(page_id)

    def run_volumes(self, volume_ids: list):
        for volume_id in tqdm.tqdm(volume_ids):
            logger.info(f"Volume {volume_id}")
            page_ids = [page['id'] for page in api_client.paginate('ListElementChildren', id=volume_id)]
            self.run_pages(page_ids)

    def run_folders(self, element_ids: list, volume_type: str):
        for elem_id in tqdm.tqdm(element_ids):
            logger.info(f"Folder {elem_id}")
            vol_ids = [page['id'] for page in
                       api_client.paginate('ListElementChildren', id=elem_id, recursive=True, type=volume_type)]
            self.run_volumes(vol_ids)

    def run_corpora(self, corpus_ids: list, volume_type: str):
        for corpus_id in tqdm.tqdm(corpus_ids):
            logger.info(f"Corpus {corpus_id}")
            vol_ids = [vol['id'] for vol in api_client.paginate('ListElements', corpus=corpus_id, type=volume_type)]
            self.run_volumes(vol_ids)
class Split(Enum):
    Train: int = 0
    Test: int = 1
    Validation: int = 2


class KaldiPartitionSplitter:
    def __init__(self, out_dir_base='/tmp/kaldi_data', split_train_ratio=0.8, split_test_ratio=0.1):
        self.out_dir_base = out_dir_base
        self.split_train_ratio = split_train_ratio
        self.split_test_ratio = split_test_ratio
        self.split_val_ratio = 1 - self.split_train_ratio - self.split_test_ratio

    def page_level_split(self, line_ids: list) -> dict:
        # Split at page level so that lines from the same page never end up in different sets
        page_ids = list({'_'.join(line_id.split('_')[:-1]) for line_id in line_ids})
        random.shuffle(page_ids)
        page_count = len(page_ids)
        train_page_ids = page_ids[:round(page_count * self.split_train_ratio)]
        page_ids = page_ids[round(page_count * self.split_train_ratio):]
        test_page_ids = page_ids[:round(page_count * self.split_test_ratio)]
        page_ids = page_ids[round(page_count * self.split_test_ratio):]
        val_page_ids = page_ids
        page_dict = {page_id: Split.Train.value for page_id in train_page_ids}
        page_dict.update({page_id: Split.Test.value for page_id in test_page_ids})
        page_dict.update({page_id: Split.Validation.value for page_id in val_page_ids})
        return page_dict
    def create_partitions(self):
        # List every extracted line image and assign it to a split through its page id
        lines_path = Path(f'{self.out_dir_base}/Lines')
        line_ids = [str(file.relative_to(lines_path).with_suffix('')) for file in lines_path.glob('**/*.jpg')]
        page_dict = self.page_level_split(line_ids)

        datasets = [[] for _ in range(3)]
        for line_id in line_ids:
            page_id = '_'.join(line_id.split('_')[:-1])
            split_id = page_dict[page_id]
            datasets[split_id].append(line_id)

        partitions_dir = os.path.join(self.out_dir_base, 'Partitions')
        os.makedirs(partitions_dir, exist_ok=True)
        for i, dataset in enumerate(datasets):
            # One partition file per split (e.g. TrainLines.lst) listing the line ids it contains
            write_file(os.path.join(partitions_dir, f'{Split(i).name}Lines.lst'),
                       '\n'.join(dataset) + '\n')
def create_parser():
    parser = argparse.ArgumentParser(
        description="Script to generate Kaldi training data from annotations from Arkindex",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('-n', '--dataset_name', type=str, required=True,
                        help='Name of the dataset being created for kaldi '
                             '(useful to distinguish datasets in the Lines or Transcriptions directories)')
    parser.add_argument('-o', '--out_dir', type=str, required=True,
                        help='output directory')
    parser.add_argument('--train_ratio', type=float, default=0.8,
                        help='Ratio of pages to be used in train (between 0 and 1)')
    parser.add_argument('--test_ratio', type=float, default=0.1,
                        help='Ratio of pages to be used in test (between 0 and 1 - train_ratio)')
    parser.add_argument('-e', '--extraction_mode', type=lambda x: Extraction[x], default=Extraction.boundingRect,
                        help=f'Mode for extracting the line images: {[e.name for e in Extraction]}')
    group = parser.add_mutually_exclusive_group(required=False)
    group.add_argument('--grayscale', action='store_true',
                       help='Convert images to grayscale')
    group.add_argument('--color', dest='grayscale', action='store_false',
                       help='Use color images')
    parser.set_defaults(grayscale=True)
    parser.add_argument('--corpora', nargs='*',
                        help='List of corpus ids to be used, separated by spaces')
    parser.add_argument('--folders', type=str, nargs='*',
                        help='List of folder ids to be used, separated by spaces. '
                             'Elements of `volume_type` will be searched recursively in these folders')
    parser.add_argument('--volumes', nargs='*',
                        help='List of volume ids to be used, separated by spaces')
    parser.add_argument('--pages', nargs='*',
                        help='List of page ids to be used, separated by spaces')
    parser.add_argument('-v', '--volume_type', type=str, default='volume',
                        help='Volumes (1 level above page) may have a different name on corpora')
    parser.add_argument('--accepted_slugs', nargs='*',
                        help='List of accepted slugs for downloading transcriptions')
    parser.add_argument('--accepted_classes', nargs='*',
                        help='List of accepted ml_class names. Filter lines by class of related elements')
    parser.add_argument('--filter_printed', action='store_true',
                        help='Filter out lines annotated as printed')
    return parser
def main():
    args = create_parser().parse_args()
    kaldi_data_generator = KaldiDataGenerator(dataset_name=args.dataset_name,
                                              out_dir_base=args.out_dir,
                                              grayscale=args.grayscale,
                                              extraction=args.extraction_mode,
                                              accepted_slugs=args.accepted_slugs,
                                              accepted_classes=args.accepted_classes,
                                              filter_printed=args.filter_printed)
    kaldi_partitioner = KaldiPartitionSplitter(out_dir_base=args.out_dir,
                                               split_train_ratio=args.train_ratio,
                                               split_test_ratio=args.test_ratio)

    # extract all the lines and transcriptions
    if args.pages:
        kaldi_data_generator.run_pages(args.pages)
    if args.volumes:
        kaldi_data_generator.run_volumes(args.volumes)
    if args.folders:
        kaldi_data_generator.run_folders(args.folders, args.volume_type)
    if args.corpora:
        kaldi_data_generator.run_corpora(args.corpora, args.volume_type)

    # create partitions from all the extracted data
    kaldi_partitioner.create_partitions()


if __name__ == '__main__':
    main()
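# Example invocation (the script name, ids and paths below are placeholders; this assumes
# the Arkindex API credentials expected by options_from_env() are set in the environment):
#
#   python kaldi_data_generator.py -n my_dataset -o /tmp/kaldi_data \
#       --volumes <volume-uuid> -e polygon --grayscale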