Skip to content
Snippets Groups Projects
Commit 2f412361 authored by Martin's avatar Martin
Browse files

sort lines before extracting

parent aac3d0f7
No related branches found
No related tags found
No related merge requests found
......@@ -8,6 +8,7 @@ import random
from enum import Enum
from io import BytesIO
from pathlib import Path
from typing import Tuple
import tqdm
......@@ -18,6 +19,8 @@ from PIL import Image
from apistar.exceptions import ErrorResponse
from arkindex import ArkindexClient, options_from_env
Box = Tuple[int, int, int, int]
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s %(levelname)s/%(name)s: %(message)s"
......@@ -89,9 +92,7 @@ class KaldiDataGenerator:
def extract_lines(self, page_id: str):
count = 0
line_bounding_rects = []
line_polygons = []
line_transcriptions = []
lines = []
try:
for res in api_client.paginate('ListTranscriptions', id=page_id, type='line'):
if self.should_filter_by_slug and res['source']['slug'] not in self.accepted_slugs:
......@@ -99,11 +100,9 @@ class KaldiDataGenerator:
text = res['text']
if not text or not text.strip():
continue
line_transcriptions.append(text)
polygon = np.asarray(res['zone']['polygon']).clip(0)
line_polygons.append(polygon)
[x, y, w, h] = cv2.boundingRect(polygon)
line_bounding_rects.append([x, y, w, h])
lines.append(((x, y, w, h), polygon, text))
count += 1
except ErrorResponse as e:
logger.info(f"ListTranscriptions failed {e.status_code} - {e.title} - {e.content} - {page_id}")
......@@ -119,24 +118,27 @@ class KaldiDataGenerator:
img = self.get_image(full_image_url, page_id=page_id)
# sort vertically then horizontally
sorted_lines = sorted(lines, key=lambda key: (key[0][1], key[0][0]))
if self.extraction_mode == Extraction.boundingRect:
for i, [x, y, w, h] in enumerate(line_bounding_rects):
for i, ((x, y, w, h), polygon, text) in enumerate(sorted_lines):
cropped = img[y:y + h, x:x + w].copy()
cv2.imwrite(f'{self.out_line_img_dir}/{page_id}_{i}.jpg', cropped)
elif self.extraction_mode == Extraction.polygon:
for i, (polygon, rect) in enumerate(zip(line_polygons, line_bounding_rects)):
for i, (rect, polygon, text) in enumerate(sorted_lines):
polygon_img = self.extract_polygon_image(img, polygon=polygon, rect=rect)
cv2.imwrite(f'{self.out_line_img_dir}/{page_id}_{i}.jpg', polygon_img)
else:
raise ValueError("Unsupported extraction mode")
for i, text in enumerate(line_transcriptions):
for i, (rect, polygon, text) in enumerate(sorted_lines):
write_file(f"{self.out_line_text_dir}/{page_id}_{i}.txt", text)
@staticmethod
def extract_polygon_image(img: 'np.ndarray', polygon: 'np.ndarray', rect: list) -> 'np.ndarray':
def extract_polygon_image(img: 'np.ndarray', polygon: 'np.ndarray', rect: Box) -> 'np.ndarray':
pts = polygon.copy()
[x, y, w, h] = rect
cropped = img[y:y + h, x:x + w].copy()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment