diff --git a/kaldi_data_generator.py b/kaldi_data_generator.py index 736f7151746165cb3a08601c259d19b860bbcff2..644091048f1bd332c897a04ddc81dbfcb7e3920e 100644 --- a/kaldi_data_generator.py +++ b/kaldi_data_generator.py @@ -94,7 +94,7 @@ class KaldiDataGenerator: img = cv2.imread(out_full_img_path) return img - def extract_lines(self, page_id: str): + def extract_lines(self, page_id: str, image_data: dict): count = 0 lines = [] try: @@ -114,7 +114,7 @@ class KaldiDataGenerator: accepted_zones.append(elt['zone']['id']) logger.info('Number of accepted zone for page {} : {}'.format(page_id,len(accepted_zones))) - for res in api_client.paginate('ListTranscriptions', id=page_id, type='line'): + for res in api_client.paginate('ListTranscriptions', id=page_id, type='line', recursive=True): if self.should_filter_by_slug and res['source']['slug'] not in self.accepted_slugs: continue @@ -124,7 +124,15 @@ class KaldiDataGenerator: text = res['text'] if not text or not text.strip(): continue - polygon = np.asarray(res['zone']['polygon']).clip(0) + + if res['zone']: + polygon = res['zone']['polygon'] + elif res['element']['zone']: + polygon = res['element']['zone']['polygon'] + else: + raise ValueError(f"Data problem with polygon :: {res}") + + polygon = np.asarray(polygon).clip(0) [x, y, w, h] = cv2.boundingRect(polygon) lines.append(((x, y, w, h), polygon, text)) count += 1 @@ -136,9 +144,9 @@ class KaldiDataGenerator: logger.info(f"Page {page_id} skipped, because it has no lines") return - full_image_url = res['zone']['image']['s3_url'] + full_image_url = image_data['s3_url'] if full_image_url is None: - full_image_url = res['zone']['image']['url'] + '/full/full/0/default.jpg' + full_image_url = image_data['url'] + '/full/full/0/default.jpg' img = self.get_image(full_image_url, page_id=page_id) @@ -175,16 +183,18 @@ class KaldiDataGenerator: dst2 = bg + dst return dst2 - def run_pages(self, page_ids: list): - for page_id in tqdm.tqdm(page_ids): + def run_pages(self, pages: list): + for page in tqdm.tqdm(pages): + page_id = page['id'] + image_data = page['zone']['image'] logger.debug(f"Page {page_id}") - self.extract_lines(page_id) + self.extract_lines(page_id, image_data) def run_volumes(self, volume_ids: list): for volume_id in tqdm.tqdm(volume_ids): logger.info(f"Volume {volume_id}") - page_ids = [page['id'] for page in api_client.paginate('ListElementChildren', id=volume_id)] - self.run_pages(page_ids) + pages = [page for page in api_client.paginate('ListElementChildren', id=volume_id)] + self.run_pages(pages) def run_folders(self, element_ids: list, volume_type: str): for elem_id in tqdm.tqdm(element_ids): @@ -266,7 +276,7 @@ def create_parser(): parser.add_argument('--train_ratio', type=float, default=0.8, help='Ratio of pages to be used in train (between 0 and 1)') parser.add_argument('--test_ratio', type=float, default=0.1, - help='Ratio of pages to be used in train (between 0 and 1 - train_ratio)') + help='Ratio of pages to be used in test (between 0 and 1 - train_ratio)') parser.add_argument('-e', '--extraction_mode', type=lambda x: Extraction[x], default=Extraction.boundingRect, help=f'Mode for extracting the line images: {[e.name for e in Extraction]}') @@ -284,8 +294,8 @@ def create_parser(): 'Elements of `volume_type` will be searched recursively in these folders') parser.add_argument('--volumes', nargs='*', help='List of volume ids to be used, separated by spaces') - parser.add_argument('--pages', nargs='*', - help='List of page ids to be used, separated by spaces') + # parser.add_argument('--pages', nargs='*', + # help='List of page ids to be used, separated by spaces') parser.add_argument('-v', '--volume_type', type=str, default='volume', help='Volumes (1 level above page) may have a different name on corpora') @@ -317,8 +327,8 @@ def main(): split_train_ratio=args.train_ratio, split_test_ratio=args.test_ratio) # extract all the lines and transcriptions - if args.pages: - kaldi_data_generator.run_pages(args.pages) + # if args.pages: + # kaldi_data_generator.run_pages(args.pages) if args.volumes: kaldi_data_generator.run_volumes(args.volumes) if args.folders: