#!/usr/bin/env python3 # -*- coding: utf-8 -*- import argparse import os import random from collections import Counter, defaultdict from enum import Enum from itertools import groupby from pathlib import Path from typing import List import numpy as np import tqdm from apistar.exceptions import ErrorResponse from arkindex import options_from_env from line_image_extractor.extractor import extract, read_img, save_img from line_image_extractor.image_utils import WHITE, Extraction, rotate_and_trim import jsonargparse from kaldi_data_generator.arguments import ( CommonArgs, FilterArgs, ImageArgs, SelectArgs, SplitArgs, Style, ) from kaldi_data_generator.image_utils import download_image, resize_transcription_data from kaldi_data_generator.utils import ( CachedApiClient, TranscriptionData, export_parameters, logger, write_file, ) SEED = 42 random.seed(SEED) MANUAL = "manual" TEXT_LINE = "text_line" DEFAULT_RESCALE = 1.0 STYLE_CLASSES = [el.value for el in Style] ROTATION_CLASSES_TO_ANGLES = { "rotate_0": 0, "rotate_left_90": 90, "rotate_180": 180, "rotate_right_90": -90, } def create_api_client(cache_dir=None): logger.info("Creating API client") # return ArkindexClient(**options_from_env()) return CachedApiClient(cache_root=cache_dir, **options_from_env()) class HTRDataGenerator: def __init__( self, format, dataset_name="my_dataset", out_dir_base="data", image=ImageArgs(), common=CommonArgs(), filter=FilterArgs(), api_client=None, ): self.format = format self.out_dir_base = out_dir_base self.dataset_name = dataset_name self.grayscale = image.grayscale self.extraction_mode = Extraction[image.extraction_mode.value] self.accepted_classes = filter.accepted_classes self.ignored_classes = filter.ignored_classes self.should_filter_by_class = bool(self.accepted_classes) or bool( self.ignored_classes ) self.accepted_worker_version_ids = filter.accepted_worker_version_ids self.should_filter_by_worker = bool(self.accepted_worker_version_ids) self.style = filter.style self.should_filter_by_style = bool(self.style) self.accepted_metadatas = filter.accepted_metadatas self.should_filter_by_metadatas = bool(self.accepted_metadatas) self.transcription_type = filter.transcription_type.value self.skip_vertical_lines = filter.skip_vertical_lines self.skipped_pages_count = 0 self.skipped_vertical_lines_count = 0 self.accepted_lines_count = 0 self.max_deskew_angle = image.max_deskew_angle self.skew_angle = image.skew_angle self.should_rotate = image.should_rotate if image.scale_x or image.scale_y_top or image.scale_y_bottom: self.should_resize_polygons = True # use 1.0 as default - no resize, if not specified self.scale_x = image.scale_x or DEFAULT_RESCALE self.scale_y_top = image.scale_y_top or DEFAULT_RESCALE self.scale_y_bottom = image.scale_y_bottom or DEFAULT_RESCALE else: self.should_resize_polygons = False self.api_client = api_client if MANUAL in self.accepted_worker_version_ids: self.accepted_worker_version_ids[ self.accepted_worker_version_ids.index(MANUAL) ] = None if self.format == "kraken": self.out_line_dir = out_dir_base os.makedirs(self.out_line_dir, exist_ok=True) else: self.out_line_text_dir = os.path.join( self.out_dir_base, "Transcriptions", self.dataset_name ) os.makedirs(self.out_line_text_dir, exist_ok=True) self.out_line_img_dir = os.path.join( self.out_dir_base, "Lines", self.dataset_name ) os.makedirs(self.out_line_img_dir, exist_ok=True) self.cache_dir = Path(common.cache_dir) logger.info(f"Setting up cache to {self.cache_dir}") self.img_cache_dir = self.cache_dir / "images" self.img_cache_dir.mkdir(exist_ok=True, parents=True) if not any(self.img_cache_dir.iterdir()): logger.info("Cache is empty, no need to check") self._cache_is_empty = True else: self._cache_is_empty = False if self.grayscale: self._color = "grayscale" else: self._color = "rgb" def get_image(self, image_url: str, page_id: str) -> "np.ndarray": # id is last part before full/full/0/default.jpg img_id = image_url.split("/")[-5].replace("%2F", "/") cached_img_path = self.img_cache_dir / self._color / img_id if not self._cache_is_empty and cached_img_path.exists(): logger.info(f"Cached image exists: {cached_img_path} - {page_id}") else: logger.info(f"Image not in cache: {cached_img_path} - {page_id}") cached_img_path.parent.mkdir(exist_ok=True, parents=True) pil_img = download_image(image_url) if self.grayscale: pil_img = pil_img.convert("L") pil_img.save(cached_img_path, format="jpeg") img = read_img(cached_img_path, self.grayscale) return img def metadata_filtering(self, elt): metadatas = { metadata["name"]: metadata["value"] for metadata in elt["metadata"] } for meta in self.accepted_metadatas: if not ( meta in metadatas and metadatas[meta] == self.accepted_metadatas[meta] ): return False return True def get_accepted_zones(self, page_id: str): try: accepted_zones = [] for elt in self.api_client.cached_paginate( "ListElementChildren", id=page_id, with_classes=self.should_filter_by_class, with_metadata=self.should_filter_by_metadatas, ): elem_classes = [c for c in elt["classes"] if c["state"] != "rejected"] should_accept = True if self.should_filter_by_class: # at first filter to only have elements with accepted classes # if accepted classes list is empty then should accept all # except for ignored classes should_accept = len(self.accepted_classes) == 0 for classification in elem_classes: class_name = classification["ml_class"]["name"] if class_name in self.accepted_classes: should_accept = True break elif class_name in self.ignored_classes: should_accept = False break if not should_accept: continue if self.should_filter_by_style: style_counts = Counter() for classification in elem_classes: class_name = classification["ml_class"]["name"] if class_name in STYLE_CLASSES: style_counts[class_name] += 1 if len(style_counts) == 0: # no handwritten or typewritten found, so other found_class = Style.other elif len(style_counts) == 1: found_class = list(style_counts.keys())[0] found_class = Style(found_class) else: raise ValueError( f"Multiple style classes on the same element! {elt['id']} - {elem_classes}" ) should_accept = found_class == self.style if not should_accept: continue if self.should_filter_by_metadatas: if self.metadata_filtering(elt): accepted_zones.append(elt["zone"]["id"]) else: accepted_zones.append(elt["zone"]["id"]) logger.info( "Number of accepted zone for page {} : {}".format( page_id, len(accepted_zones) ) ) return accepted_zones except ErrorResponse as e: logger.info( f"ListTranscriptions failed {e.status_code} - {e.title} - {e.content} - {page_id}" ) raise e def _validate_transcriptions(self, page_id: str, lines: List[TranscriptionData]): if not lines: return line_elem_counter = Counter([trans.element_id for trans in lines]) most_common = line_elem_counter.most_common(10) if most_common[0][-1] > 1: logger.error("Line elements have multiple transcriptions! Showing top 10:") logger.error(f"{most_common}") raise ValueError(f"Multiple transcriptions: {most_common[0]}") worker_version_counter = Counter([trans.worker_version_id for trans in lines]) if len(worker_version_counter) > 1: logger.warning( f"There are transcriptions from multiple worker versions on this page: {page_id}:" ) logger.warning( f"Top 10 worker versions: {worker_version_counter.most_common(10)}" ) def _choose_best_transcriptions( self, lines: List[TranscriptionData] ) -> List[TranscriptionData]: """ Get the best transcription based on the order of accepted worker version ids. :param lines: :return: """ if not lines: return [] trans_by_element = defaultdict(list) for line in lines: trans_by_element[line.element_id].append(line) best_transcriptions = [] for elem, trans_list in trans_by_element.items(): tmp_dict = {t.worker_version_id: t for t in trans_list} for wv in self.accepted_worker_version_ids: if wv in tmp_dict: best_transcriptions.append(tmp_dict[wv]) break else: logger.info(f"No suitable trans found for {elem}") return best_transcriptions def get_transcriptions(self, page_id: str, accepted_zones): lines = [] try: for res in self.api_client.cached_paginate( "ListTranscriptions", id=page_id, recursive=True ): if ( self.should_filter_by_worker and res["worker_version_id"] not in self.accepted_worker_version_ids ): continue if ( self.should_filter_by_class or self.should_filter_by_style or self.should_filter_by_metadatas ) and (res["element"]["zone"]["id"] not in accepted_zones): continue if res["element"]["type"] != self.transcription_type: continue text = res["text"] if not text or not text.strip(): continue if "\n" in text.strip() and not self.transcription_type == "text": elem_id = res["element"]["id"] raise ValueError( f"Newlines are not allowed in line transcriptions - {page_id} - {elem_id} - {text}" ) if "zone" in res: polygon = res["zone"]["polygon"] elif "element" in res: polygon = res["element"]["zone"]["polygon"] else: raise ValueError(f"Data problem with polygon :: {res}") trans_data = TranscriptionData( element_id=res["element"]["id"], element_name=res["element"]["name"], polygon=polygon, text=text, trans_id=res["id"], worker_version_id=res["worker_version_id"], ) lines.append(trans_data) if self.accepted_worker_version_ids: # if accepted worker versions have been defined then use them lines = self._choose_best_transcriptions(lines) else: # if no accepted worker versions have been defined # then check that there aren't multiple transcriptions # on the same text line self._validate_transcriptions(page_id, lines) if self.should_rotate: classes_by_elem = self.get_children_classes(page_id) for trans in lines: rotation_classes = [ c for c in classes_by_elem[trans.element_id] if c in ROTATION_CLASSES_TO_ANGLES ] if len(rotation_classes) > 0: if len(rotation_classes) > 1: logger.warning( f"Several rotation classes = {len(rotation_classes)} - {trans.element_id}" ) trans.rotation_class = rotation_classes[0] else: logger.warning(f"No rotation classes on {trans.element_id}") count_skipped = 0 if self.skip_vertical_lines: filtered_lines = [] for line in lines: if line.is_vertical: count_skipped += 1 continue filtered_lines.append(line) lines = filtered_lines count = len(lines) return lines, count, count_skipped except ErrorResponse as e: logger.info( f"ListTranscriptions failed {e.status_code} - {e.title} - {e.content} - {page_id}" ) raise e def get_children_classes(self, page_id): return { elem["id"]: [ best_class["ml_class"]["name"] for best_class in elem["classes"] if best_class["state"] != "rejected" ] for elem in self.api_client.cached_paginate( "ListElementChildren", id=page_id, recursive=True, type=TEXT_LINE, with_classes=True, ) } def _save_line_image( self, page_id, line_img, manifest_fp=None, trans: TranscriptionData = None ): # Get line id line_id = trans.element_id # Get line number from its name line_number = trans.element_name.split("_")[-1] if self.should_rotate: if trans.rotation_class: rotate_angle = ROTATION_CLASSES_TO_ANGLES[trans.rotation_class] line_img = rotate_and_trim(line_img, rotate_angle, WHITE) if self.format == "kraken": # Save image using the template {page_id}_{line_number}_{line_id} # TODO: check if (0>3) is enough (pad line_number to 3 digits) save_img( f"{self.out_line_dir}/{page_id}_{line_number:0>3}_{line_id}.png", line_img, ) manifest_fp.write(f"{page_id}_{line_number:0>3}_{line_id}.png\n") else: save_img( f"{self.out_line_img_dir}/{page_id}_{line_number:0>3}_{line_id}.jpg", line_img, ) def extract_lines(self, page_id: str, image_data: dict): if ( self.should_filter_by_class or self.should_filter_by_style or self.should_filter_by_metadatas ): accepted_zones = self.get_accepted_zones(page_id) else: accepted_zones = [] lines, count, count_skipped = self.get_transcriptions(page_id, accepted_zones) if count == 0: self.skipped_pages_count += 1 logger.info(f"Page {page_id} skipped, because it has no lines") return logger.debug(f"Total num of lines {count + count_skipped}") logger.debug(f"Num of accepted lines {count}") logger.debug(f"Num of skipped lines {count_skipped}") self.skipped_vertical_lines_count += count_skipped self.accepted_lines_count += count full_image_url = image_data["s3_url"] if full_image_url is None: full_image_url = image_data["url"] + "/full/full/0/default.jpg" img = self.get_image(full_image_url, page_id=page_id) # sort vertically then horizontally sorted_lines = sorted(lines, key=lambda key: (key.rect.y, key.rect.x)) if self.should_resize_polygons: sorted_lines = [ resize_transcription_data( line, image_data["width"], image_data["height"], self.scale_x, self.scale_y_top, self.scale_y_bottom, ) for line in sorted_lines ] if self.format == "kraken": manifest_fp = open(f"{self.out_line_dir}/manifest.txt", "a") # append to file, not re-write it else: # not needed for kaldi manifest_fp = None for trans in sorted_lines: extracted_img = extract( img=img, polygon=trans.polygon, bbox=trans.rect, extraction_mode=self.extraction_mode, max_deskew_angle=self.max_deskew_angle, skew_angle=self.skew_angle, grayscale=self.grayscale, ) # don't enumerate, read the line number from the elements's name (e.g. line_xx) so that it matches with Arkindex self._save_line_image(page_id, extracted_img, manifest_fp, trans) if self.format == "kraken": manifest_fp.close() for trans in sorted_lines: line_number = trans.element_name.split("_")[-1] line_id = trans.element_id if self.format == "kraken": write_file( f"{self.out_line_dir}/{page_id}_{line_number:0>3}_{line_id}.gt.txt", trans.text, ) else: write_file( f"{self.out_line_text_dir}/{page_id}_{line_number:0>3}_{line_id}.txt", trans.text, ) def run_selection(self, select): """ Update select to keep track of selected ids. """ selected_elems = [e for e in self.api_client.paginate("ListSelection")] for elem_type, elems_of_type in groupby( selected_elems, key=lambda x: x["type"] ): elements_ids = [el["id"] for el in elems_of_type] if elem_type == "page": select.pages += elements_ids elif elem_type == "volume": select.volumes += elements_ids elif elem_type == "folder": select.folders += elements_ids else: raise ValueError(f"Unsupported element type {elem_type} in selection!") return select def run_pages(self, pages: list): if all(isinstance(n, str) for n in pages): for page in pages: elt = self.api_client.request("RetrieveElement", id=page) page_id = elt["id"] image_data = elt["zone"]["image"] logger.debug(f"Page {page_id}") self.extract_lines(page_id, image_data) else: for page in tqdm.tqdm(pages): page_id = page["id"] image_data = page["zone"]["image"] logger.debug(f"Page {page_id}") self.extract_lines(page_id, image_data) def run_volumes(self, volume_ids: list): for volume_id in tqdm.tqdm(volume_ids): logger.info(f"Volume {volume_id}") pages = [ page for page in self.api_client.cached_paginate( "ListElementChildren", id=volume_id, recursive=True, type="page" ) ] self.run_pages(pages) def run_folders(self, element_ids: list, volume_type: str): for elem_id in tqdm.tqdm(element_ids): logger.info(f"Folder {elem_id}") vol_ids = [ page["id"] for page in self.api_client.cached_paginate( "ListElementChildren", id=elem_id, recursive=True, type=volume_type ) ] self.run_volumes(vol_ids) def run_corpora(self, corpus_ids: list, volume_type: str): for corpus_id in tqdm.tqdm(corpus_ids): logger.info(f"Corpus {corpus_id}") vol_ids = [ vol["id"] for vol in self.api_client.cached_paginate( "ListElements", corpus=corpus_id, type=volume_type ) ] self.run_volumes(vol_ids) class Split(Enum): Train: str = "train" Test: str = "test" Validation: str = "val" class KaldiPartitionSplitter: def __init__( self, out_dir_base="/tmp/kaldi_data", split_train_ratio=0.8, split_val_ratio=0.1, split_test_ratio=0.1, use_existing_split=False, ): self.out_dir_base = out_dir_base self.split_train_ratio = split_train_ratio self.split_test_ratio = split_test_ratio self.split_val_ratio = split_val_ratio self.use_existing_split = use_existing_split def page_level_split(self, line_ids: list) -> dict: """ Split pages into train, validation and test subsets. Don't split lines to avoid data leakage. line_ids (list): a list of line ids named {page_id}_{line_number}_{line_id} """ # Get page ids from line ids to create splits at page level page_ids = ["_".join(line_id.split("_")[:-2]) for line_id in line_ids] # Remove duplicates and sort for reproducibility page_ids = sorted(set(page_ids)) random.Random(SEED).shuffle(page_ids) page_count = len(page_ids) # Use np.split to split in three sets stop_train_idx = round(page_count * self.split_train_ratio) stop_val_idx = stop_train_idx + round(page_count * self.split_val_ratio) train_page_ids, val_page_ids, test_page_ids = np.split( page_ids, [stop_train_idx, stop_val_idx] ) # Build dictionary that will be used to split lines {id: split} page_dict = {page_id: Split.Train.value for page_id in train_page_ids} page_dict.update({page_id: Split.Validation.value for page_id in val_page_ids}) page_dict.update({page_id: Split.Test.value for page_id in test_page_ids}) return page_dict def existing_split(self, line_ids: list) -> list: """ Expect line_ids to be named {split}/{path_to_image} where split in ["train", "val", "test"] """ split_dict = {split: [] for split in Split} for line_id in line_ids: split_prefix = line_id.split("/")[0].lower() split_dict[split_prefix].append(line_id) return split_dict def create_partitions(self): """ """ logger.info(f"Creating {[split.value for split in Split]} partitions") # Get all images ids (and remove extension) lines_path = Path(f"{self.out_dir_base}/Lines") line_ids = [ str(file.relative_to(lines_path).with_suffix("")) for file in sorted(lines_path.glob("**/*.jpg")) ] if self.use_existing_split: logger.info("Using existing split") datasets = self.existing_split(line_ids) else: page_dict = self.page_level_split(line_ids) # extend this split for lines datasets = {s.value: [] for s in Split} for line_id in line_ids: page_id = "_".join(line_id.split("_")[:-2]) split_id = page_dict[page_id] datasets[split_id].append(line_id) partitions_dir = os.path.join(self.out_dir_base, "Partitions") os.makedirs(partitions_dir, exist_ok=True) for split, split_line_ids in datasets.items(): if not split_line_ids: logger.info(f"Partition {split} is empty! Skipping...") continue file_name = f"{partitions_dir}/{Split(split).name}Lines.lst" write_file(file_name, "\n".join(split_line_ids) + "\n") return datasets def run(format, dataset_name, out_dir, common, image, split, select, filter): api_client = create_api_client(Path(common.cache_dir)) if not split.split_only: data_generator = HTRDataGenerator( format=format, dataset_name=dataset_name, out_dir_base=out_dir, common=common, image=image, filter=filter, api_client=api_client, ) # extract all the lines and transcriptions if select.selection: select = data_generator.run_selection(select) if select.pages: data_generator.run_pages(select.pages) if select.volumes: data_generator.run_volumes(select.volumes) if select.folders: data_generator.run_folders(select.folders, select.volume_type) if select.corpora: data_generator.run_corpora(select.corpora, select.volume_type) if data_generator.skipped_vertical_lines_count > 0: logger.info( f"Number of skipped pages: {data_generator.skipped_pages_count}" ) _skipped_vertical_count = data_generator.skipped_vertical_lines_count _total_count = _skipped_vertical_count + data_generator.accepted_lines_count skipped_ratio = _skipped_vertical_count / _total_count * 100 logger.info( f"Skipped {data_generator.skipped_vertical_lines_count} vertical lines ({round(skipped_ratio, 2)}%)" ) else: logger.info("Creating a split from already downloaded files") data_generator = None if not split.no_split: kaldi_partitioner = KaldiPartitionSplitter( out_dir_base=out_dir, split_train_ratio=split.train_ratio, split_val_ratio=split.val_ratio, split_test_ratio=split.test_ratio, use_existing_split=split.use_existing_split, ) # create partitions from all the extracted data datasets = kaldi_partitioner.create_partitions() else: logger.info("No split to be done") datasets = {} logger.info("DONE") export_parameters( format, dataset_name, out_dir, common, image, split, select, filter, datasets, arkindex_api_url=options_from_env()["base_url"], ) logger.warning( f"Consider cleaning your cache directory {common.cache_dir} if you are done." ) def get_args(): parser = jsonargparse.ArgumentParser( description="Script to generate Kaldi or kraken training data from annotations from Arkindex", formatter_class=argparse.ArgumentDefaultsHelpFormatter, parse_as_dict=True, ) parser.add_argument( "--config", action=jsonargparse.ActionConfigFile, help="Configuration file" ) parser.add_argument( "-f", "--format", type=str, required=True, help="is the data generated going to be used for kaldi or kraken", ) parser.add_argument( "-d", "--dataset_name", type=str, required=True, help="Name of the dataset being created for kaldi or kraken " "(useful for distinguishing different datasets when in Lines or Transcriptions directory)", ) parser.add_argument( "-o", "--out_dir", type=str, required=True, help="output directory" ) parser.add_class_arguments(CommonArgs, "common") parser.add_class_arguments(ImageArgs, "image") parser.add_class_arguments(SplitArgs, "split") parser.add_class_arguments(FilterArgs, "filter") parser.add_class_arguments(SelectArgs, "select") args = parser.parse_args(with_meta=False) args["common"] = CommonArgs(**args["common"]) args["image"] = ImageArgs(**args["image"]) args["split"] = SplitArgs(**args["split"]) args["select"] = SelectArgs(**args["select"]) args["filter"] = FilterArgs(**args["filter"]) # Check overlap of accepted and ignored classes accepted_classes = args["filter"].accepted_classes ignored_classes = args["filter"].accepted_classes if accepted_classes and ignored_classes: if set(accepted_classes) & set(ignored_classes): parser.error( f"--filter.accepted_classes and --filter.ignored_classes values must not overlap ({accepted_classes} - {ignored_classes})" ) if args["filter"].style and (accepted_classes or ignored_classes): if set(STYLE_CLASSES) & (set(accepted_classes) | set(ignored_classes)): parser.error( f"--style class values ({STYLE_CLASSES}) shouldn't be in the accepted_classes list " f"(or ignored_classes list) " "if --filter.style is used with either --filter.accepted_classes or --filter.ignored_classes." ) del args["config"] return args def main(): args = get_args() logger.info(f"Arguments: {args} \n") run(**args) if __name__ == "__main__": main()