import argparse
import logging
import os
import random
from enum import Enum
from io import BytesIO
from pathlib import Path
import cv2
import numpy as np
import requests
import tqdm
from PIL import Image
from apistar.exceptions import ErrorResponse
from arkindex import ArkindexClient, options_from_env
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s %(levelname)s/%(name)s: %(message)s"
)
logger = logging.getLogger(os.path.basename(__file__))
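# The Arkindex API calls below go through a module-level client. A minimal sketch using the
# arkindex-client helpers; options_from_env() reads the endpoint and token from the environment.
api_client = ArkindexClient(**options_from_env())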
def download_image(url):
'''
Download an image and open it with Pillow
'''
assert url.startswith('http'), 'Image URL must be HTTP(S)'
# Download the image
# Cannot use stream=True as urllib's responses do not support the seek(int) method,
# which is explicitly required by Image.open on file-like objects
resp = requests.get(url)
resp.raise_for_status()
# Preprocess the image and prepare it for classification
image = Image.open(BytesIO(resp.content))
logger.debug('Downloaded image {} - size={}x{}'.format(url, image.size[0], image.size[1]))
return image
def write_file(file_name, content):
with open(file_name, 'w') as f:
f.write(content)
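# The enum below selects how line images are cut out of the page image: either the
# axis-aligned bounding rectangle of the line, or the exact polygon masked onto a white
# background (see extract_polygon_image).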
class Extraction(Enum):
boundingRect: int = 0
polygon: int = 1
class KaldiDataGenerator:
def __init__(self, dataset_name='foo', out_dir_base='/tmp/kaldi_data', grayscale=True,
extraction=Extraction.boundingRect, accepted_slugs=None, accepted_classes=None, filter_printed=False):
self.out_dir_base = out_dir_base
self.dataset_name = dataset_name
self.grayscale = grayscale
self.accepted_slugs = accepted_slugs
self.should_filter_by_slug = bool(self.accepted_slugs)
self.accepted_classes = accepted_classes
self.should_filter_by_class = bool(self.accepted_classes)
self.should_filter_printed = filter_printed
self.extraction_mode = extraction
self.out_line_text_dir = os.path.join(self.out_dir_base, 'Transcriptions', self.dataset_name)
os.makedirs(self.out_line_text_dir, exist_ok=True)
self.out_line_img_dir = os.path.join(self.out_dir_base, 'Lines', self.dataset_name)
os.makedirs(self.out_line_img_dir, exist_ok=True)
def get_image(self, image_url: str, page_id: str) -> 'np.ndarray':
out_full_img_dir = os.path.join(self.out_dir_base, 'full', page_id)
os.makedirs(out_full_img_dir, exist_ok=True)
out_full_img_path = os.path.join(out_full_img_dir, 'full.jpg')
if self.grayscale:
download_image(image_url).convert('L').save(
out_full_img_path, format='jpeg')
img = cv2.imread(out_full_img_path, cv2.IMREAD_GRAYSCALE)
else:
download_image(image_url).save(
out_full_img_path, format='jpeg')
img = cv2.imread(out_full_img_path)
return img
def extract_lines(self, page_id: str, image_data: dict):
if self.should_filter_by_class:
accepted_zones = []
for elt in api_client.paginate('ListElementChildren', id=page_id, with_best_classes=True):
printed = True
for classification in elt['best_classes']:
if classification['ml_class']['name'] == 'handwritten':
printed = False
for classification in elt['best_classes']:
if classification['ml_class']['name'] in self.accepted_classes:
if self.should_filter_printed:
if not printed:
accepted_zones.append(elt['zone']['id'])
else:
accepted_zones.append(elt['zone']['id'])
logger.info('Number of accepted zones for page {}: {}'.format(page_id, len(accepted_zones)))
for res in api_client.paginate('ListTranscriptions', id=page_id, type='line', recursive=True):
if self.should_filter_by_slug and res['source']['slug'] not in self.accepted_slugs:
continue
if self.should_filter_by_class and res['zone']['id'] not in accepted_zones:
continue
text = res['text']
if not text or not text.strip():
continue
if res['zone']:
polygon = res['zone']['polygon']
elif res['element']['zone']:
polygon = res['element']['zone']['polygon']
else:
raise ValueError(f"Data problem with polygon :: {res}")
polygon = np.asarray(polygon).clip(0)
logger.info(f"ListTranscriptions failed {e.status_code} - {e.title} - {e.content} - {page_id}")
full_image_url = image_data['url'] + '/full/full/0/default.jpg'
img = self.get_image(full_image_url, page_id=page_id)
# sort vertically then horizontally
sorted_lines = sorted(lines, key=lambda key: (key[0][1], key[0][0]))
if self.extraction_mode == Extraction.boundingRect:
for i, ((x, y, w, h), polygon, text) in enumerate(sorted_lines):
cropped = img[y:y + h, x:x + w].copy()
cv2.imwrite(f'{self.out_line_img_dir}/{page_id}_{i}.jpg', cropped)
elif self.extraction_mode == Extraction.polygon:
for i, (rect, polygon, text) in enumerate(sorted_lines):
polygon_img = self.extract_polygon_image(img, polygon=polygon, rect=rect)
cv2.imwrite(f'{self.out_line_img_dir}/{page_id}_{i}.jpg', polygon_img)
else:
raise ValueError("Unsupported extraction mode")
for i, (rect, polygon, text) in enumerate(sorted_lines):
write_file(f"{self.out_line_text_dir}/{page_id}_{i}.txt", text)
@staticmethod
def extract_polygon_image(img: 'np.ndarray', polygon: 'np.ndarray', rect: Box) -> 'np.ndarray':
# rect is expected to unpack as (x, y, w, h), i.e. the line's bounding rectangle
pts = polygon.copy()
[x, y, w, h] = rect
# Crop the bounding rectangle, then shift the polygon into the crop's coordinate system
cropped = img[y:y + h, x:x + w].copy()
pts = pts - pts.min(axis=0)
# Fill the polygon on a black mask and keep only the pixels inside it
mask = np.zeros(cropped.shape[:2], np.uint8)
cv2.drawContours(mask, [pts], -1, (255, 255, 255), -1, cv2.LINE_AA)
dst = cv2.bitwise_and(cropped, cropped, mask=mask)
# Paint everything outside the polygon white so the background stays uniform
bg = np.ones_like(cropped, np.uint8) * 255
cv2.bitwise_not(bg, bg, mask=mask)
dst2 = bg + dst
return dst2
def run_pages(self, pages: list):
for page in tqdm.tqdm(pages):
page_id = page['id']
image_data = page['zone']['image']
self.extract_lines(page_id, image_data)
def run_volumes(self, volume_ids: list):
for volume_id in tqdm.tqdm(volume_ids):
pages = [page for page in api_client.paginate('ListElementChildren', id=volume_id)]
self.run_pages(pages)
def run_folders(self, element_ids: list, volume_type: str):
for elem_id in tqdm.tqdm(element_ids):
logger.info(f"Folder {elem_id}")
vol_ids = [page['id'] for page in
api_client.paginate('ListElementChildren', id=elem_id, recursive=True, type=volume_type)]
self.run_volumes(vol_ids)
def run_corpora(self, corpus_ids: list, volume_type: str):
for corpus_id in tqdm.tqdm(corpus_ids):
vol_ids = [vol['id'] for vol in api_client.paginate('ListElements', corpus=corpus_id, type=volume_type)]
self.run_volumes(vol_ids)
class Split(Enum):
Train: int = 0
Test: int = 1
Validation: int = 2
@property
def short_name(self) -> str:
if self == self.Validation:
return "val"
return self.name.lower()
class KaldiPartitionSplitter:
def __init__(self, out_dir_base='/tmp/kaldi_data', split_train_ratio=0.8, split_test_ratio=0.1,
use_existing_split=False):
self.out_dir_base = out_dir_base
self.split_train_ratio = split_train_ratio
self.split_test_ratio = split_test_ratio
self.split_val_ratio = 1 - self.split_train_ratio - self.split_test_ratio
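# e.g. with the default ratios (train 0.8, test 0.1), the remaining 0.1 of pages goes to validation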
self.use_existing_split = use_existing_split
def page_level_split(self, line_ids: list) -> dict:
page_ids = list({'_'.join(line_id.split('_')[:-1]) for line_id in line_ids})
random.shuffle(page_ids)
page_count = len(page_ids)
train_page_ids = page_ids[:round(page_count * self.split_train_ratio)]
page_ids = page_ids[round(page_count * self.split_train_ratio):]
test_page_ids = page_ids[:round(page_count * self.split_test_ratio)]
page_ids = page_ids[round(page_count * self.split_test_ratio):]
val_page_ids = page_ids
page_dict = {page_id: Split.Train.value for page_id in train_page_ids}
page_dict.update({page_id: Split.Test.value for page_id in test_page_ids})
page_dict.update({page_id: Split.Validation.value for page_id in val_page_ids})
return page_dict
def existing_split(self, line_ids: list) -> list:
split_dict = {split.short_name: [] for split in Split}
for line_id in line_ids:
split_prefix = line_id.split('/')[0].lower()
split_dict[split_prefix].append(line_id)
splits = [split_dict[split.short_name] for split in Split]
return splits
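# With --use_existing_split, line ids are expected to look like 'train/<page_id>_<i>',
# 'val/<page_id>_<i>' or 'test/<page_id>_<i>', i.e. prefixed with the split's short name.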
def create_partitions(self):
lines_path = Path(f'{self.out_dir_base}/Lines')
line_ids = [str(file.relative_to(lines_path).with_suffix('')) for file in lines_path.glob('**/*.jpg')]
if self.use_existing_split:
logger.info("Using existing split")
datasets = self.existing_split(line_ids)
else:
page_dict = self.page_level_split(line_ids)
datasets = [[] for _ in range(3)]
for line_id in line_ids:
page_id = '_'.join(line_id.split('_')[:-1])
split_id = page_dict[page_id]
datasets[split_id].append(line_id)
partitions_dir = os.path.join(self.out_dir_base, 'Partitions')
os.makedirs(partitions_dir, exist_ok=True)
for i, dataset in enumerate(datasets):
def create_parser():
parser = argparse.ArgumentParser(
description="Script to generate Kaldi training data from annotations from Arkindex",
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('-n', '--dataset_name', type=str,
help='Name of the dataset being created for Kaldi '
'(used to distinguish different datasets inside the Lines and Transcriptions directories)')
parser.add_argument('-o', '--out_dir', type=str, required=True,
help='output directory')
parser.add_argument('--train_ratio', type=float, default=0.8,
help='Ratio of pages to be used in train (between 0 and 1)')
parser.add_argument('--test_ratio', type=float, default=0.1,
help='Ratio of pages to be used in test (between 0 and 1 - train_ratio)')
parser.add_argument('--use_existing_split', action='store_true', default=False,
help='Use an existing split instead of a random one. '
'Expects line ids to be prefixed with the split name (train, val or test)')
parser.add_argument('--split_only', '--no_download', action='store_true', default=False,
help="Create the split from already downloaded lines; do not download the lines")
parser.add_argument('-e', '--extraction_mode', type=lambda x: Extraction[x], default=Extraction.boundingRect,
help=f'Mode for extracting the line images: {[e.name for e in Extraction]}')
group = parser.add_mutually_exclusive_group(required=False)
group.add_argument('--grayscale', action='store_true',
help='Convert images to grayscale')
group.add_argument('--color', dest='grayscale', action='store_false',
help='Use color images')
parser.set_defaults(grayscale=True)
parser.add_argument('--corpora', nargs='*',
help='List of corpus ids to be used, separated by spaces')
parser.add_argument('--folders', type=str, nargs='*',
help='List of folder ids to be used, separated by spaces. '
'Elements of `volume_type` will be searched recursively in these folders')
parser.add_argument('--volumes', nargs='*',
help='List of volume ids to be used, separated by spaces')
# parser.add_argument('--pages', nargs='*',
# help='List of page ids to be used, separated by spaces')
parser.add_argument('-v', '--volume_type', type=str, default='volume',
help='Volumes (one level above pages) may have a different type name depending on the corpus')
parser.add_argument('--accepted_slugs', nargs='*',
help='List of accepted slugs for downloading transcriptions')
parser.add_argument('--accepted_classes', nargs='*',
help='List of accepted ml_class names. Filter lines by class of related elements')
parser.add_argument('--filter_printed', action='store_true',
help='Only keep lines from elements classified as handwritten (filter out printed lines)')
return parser
parser = create_parser()
args = parser.parse_args()
if not args.dataset_name and not args.split_only:
parser.error("--dataset_name must be specified (unless --split_only is used)")
if not args.split_only:
kaldi_data_generator = KaldiDataGenerator(
dataset_name=args.dataset_name,
out_dir_base=args.out_dir,
grayscale=args.grayscale,
extraction=args.extraction_mode,
accepted_slugs=args.accepted_slugs,
accepted_classes=args.accepted_classes,
filter_printed=args.filter_printed)
# extract all the lines and transcriptions
# if args.pages:
# kaldi_data_generator.run_pages(args.pages)
if args.volumes:
kaldi_data_generator.run_volumes(args.volumes)
if args.folders:
kaldi_data_generator.run_folders(args.folders, args.volume_type)
if args.corpora:
kaldi_data_generator.run_corpora(args.corpora, args.volume_type)
else:
logger.info("Creating a split from already downloaded files")
kaldi_partitioner = KaldiPartitionSplitter(
out_dir_base=args.out_dir,
split_train_ratio=args.train_ratio,
split_test_ratio=args.test_ratio,
use_existing_split=args.use_existing_split)
# create partitions from all the extracted data
kaldi_partitioner.create_partitions()
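# Illustrative invocations (the script name and the UUID placeholders are assumptions,
# not values taken from an actual Arkindex instance):
#   python kaldi_data_generator.py -o /tmp/kaldi_data -n my_dataset --volumes <volume_uuid> --grayscale
#   python kaldi_data_generator.py -o /tmp/kaldi_data --split_only --use_existing_split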