diff --git a/kaldi_data_generator.py b/kaldi_data_generator.py index b3f99ffab7e8c826a0dbbe32248e9a57e9b40287..f3d9ce3a6685af216ed12bcdf00b0fc62115a884 100644 --- a/kaldi_data_generator.py +++ b/kaldi_data_generator.py @@ -8,6 +8,7 @@ import random from enum import Enum from io import BytesIO from pathlib import Path +from typing import Tuple import tqdm @@ -18,6 +19,8 @@ from PIL import Image from apistar.exceptions import ErrorResponse from arkindex import ArkindexClient, options_from_env +Box = Tuple[int, int, int, int] + logging.basicConfig( level=logging.INFO, format="%(asctime)s %(levelname)s/%(name)s: %(message)s" @@ -89,9 +92,7 @@ class KaldiDataGenerator: def extract_lines(self, page_id: str): count = 0 - line_bounding_rects = [] - line_polygons = [] - line_transcriptions = [] + lines = [] try: for res in api_client.paginate('ListTranscriptions', id=page_id, type='line'): if self.should_filter_by_slug and res['source']['slug'] not in self.accepted_slugs: @@ -99,11 +100,9 @@ class KaldiDataGenerator: text = res['text'] if not text or not text.strip(): continue - line_transcriptions.append(text) polygon = np.asarray(res['zone']['polygon']).clip(0) - line_polygons.append(polygon) [x, y, w, h] = cv2.boundingRect(polygon) - line_bounding_rects.append([x, y, w, h]) + lines.append(((x, y, w, h), polygon, text)) count += 1 except ErrorResponse as e: logger.info(f"ListTranscriptions failed {e.status_code} - {e.title} - {e.content} - {page_id}") @@ -119,24 +118,27 @@ class KaldiDataGenerator: img = self.get_image(full_image_url, page_id=page_id) + # sort vertically then horizontally + sorted_lines = sorted(lines, key=lambda key: (key[0][1], key[0][0])) + if self.extraction_mode == Extraction.boundingRect: - for i, [x, y, w, h] in enumerate(line_bounding_rects): + for i, ((x, y, w, h), polygon, text) in enumerate(sorted_lines): cropped = img[y:y + h, x:x + w].copy() cv2.imwrite(f'{self.out_line_img_dir}/{page_id}_{i}.jpg', cropped) elif self.extraction_mode == Extraction.polygon: - for i, (polygon, rect) in enumerate(zip(line_polygons, line_bounding_rects)): + for i, (rect, polygon, text) in enumerate(sorted_lines): polygon_img = self.extract_polygon_image(img, polygon=polygon, rect=rect) cv2.imwrite(f'{self.out_line_img_dir}/{page_id}_{i}.jpg', polygon_img) else: raise ValueError("Unsupported extraction mode") - for i, text in enumerate(line_transcriptions): + for i, (rect, polygon, text) in enumerate(sorted_lines): write_file(f"{self.out_line_text_dir}/{page_id}_{i}.txt", text) @staticmethod - def extract_polygon_image(img: 'np.ndarray', polygon: 'np.ndarray', rect: list) -> 'np.ndarray': + def extract_polygon_image(img: 'np.ndarray', polygon: 'np.ndarray', rect: Box) -> 'np.ndarray': pts = polygon.copy() [x, y, w, h] = rect cropped = img[y:y + h, x:x + w].copy()