From 540884e943f861ee3f79bcd16d1d51e670154d8a Mon Sep 17 00:00:00 2001
From: Martin Maarand <maarand@teklia.com>
Date: Wed, 2 Jun 2021 09:39:29 +0000
Subject: [PATCH] Add deskew extraction

---
 .isort.cfg                                   |   2 +-
 kaldi_data_generator/image_utils.py          | 126 ++++++++++++++
 kaldi_data_generator/kaldi_data_generator.py | 164 +++++++++++--------
 kaldi_data_generator/utils.py                |  13 ++
 tests/test_first.py                          |   7 -
 tests/test_image_utils.py                    |  49 ++++++
 6 files changed, 288 insertions(+), 73 deletions(-)
 create mode 100644 kaldi_data_generator/image_utils.py
 create mode 100644 kaldi_data_generator/utils.py
 delete mode 100644 tests/test_first.py
 create mode 100644 tests/test_image_utils.py

diff --git a/.isort.cfg b/.isort.cfg
index c88754a..1b59f25 100644
--- a/.isort.cfg
+++ b/.isort.cfg
@@ -8,4 +8,4 @@ line_length = 88
 
 default_section=FIRSTPARTY
 known_first_party =
-known_third_party = PIL,apistar,arkindex,cv2,numpy,requests,setuptools,tqdm
+known_third_party = PIL,apistar,arkindex,cv2,numpy,pytest,requests,setuptools,tqdm
diff --git a/kaldi_data_generator/image_utils.py b/kaldi_data_generator/image_utils.py
new file mode 100644
index 0000000..fb5700d
--- /dev/null
+++ b/kaldi_data_generator/image_utils.py
@@ -0,0 +1,126 @@
+# -*- coding: utf-8 -*-
+import math
+from io import BytesIO
+from typing import Tuple, Union
+
+import cv2
+import numpy as np
+import requests
+from PIL import Image, ImageChops
+
+from kaldi_data_generator.utils import logger
+
+RIGHT_ANGLE = 90
+
+Box = Tuple[int, int, int, int]
+
+
+def download_image(url):
+    """
+    Download an image and open it with Pillow
+    """
+    assert url.startswith("http"), "Image URL must be HTTP(S)"
+    # Download the image
+    # Cannot use stream=True as urllib's responses do not support the seek(int) method,
+    # which is explicitly required by Image.open on file-like objects
+    resp = requests.get(url)
+    resp.raise_for_status()
+
+    # Preprocess the image and prepare it for classification
+    image = Image.open(BytesIO(resp.content))
+    logger.debug(
+        "Downloaded image {} - size={}x{}".format(url, image.size[0], image.size[1])
+    )
+
+    return image
+
+
+def extract_polygon_image(
+    img: "np.ndarray", polygon: "np.ndarray", rect: Box
+) -> "np.ndarray":
+    pts = polygon.copy()
+    [x, y, w, h] = rect
+    cropped = img[y : y + h, x : x + w].copy()
+    pts = pts - pts.min(axis=0)
+    mask = np.zeros(cropped.shape[:2], np.uint8)
+    cv2.drawContours(mask, [pts], -1, (255, 255, 255), -1, cv2.LINE_AA)
+    dst = cv2.bitwise_and(cropped, cropped, mask=mask)
+    bg = np.ones_like(cropped, np.uint8) * 255
+    cv2.bitwise_not(bg, bg, mask=mask)
+    dst2 = bg + dst
+    return dst2
+
+
+def extract_min_area_rect_image(
+    img: "np.ndarray", polygon: "np.ndarray", rect: Box
+) -> "np.ndarray":
+    min_area_rect = cv2.minAreaRect(polygon)
+    # convert minimum area rect to polygon
+    box = cv2.boxPoints(min_area_rect)
+    box = np.int0(box)
+
+    # get min area rect image
+    box_img = extract_polygon_image(img, polygon=box, rect=rect)
+    return box_img
+
+
+# https://github.com/sbrunner/deskew
+def rotate(
+    image: np.ndarray, angle: float, background: Union[int, Tuple[int, int, int]]
+) -> np.ndarray:
+    old_width, old_height = image.shape[:2]
+    angle_radian = math.radians(angle)
+    width = abs(np.sin(angle_radian) * old_height) + abs(
+        np.cos(angle_radian) * old_width
+    )
+    height = abs(np.sin(angle_radian) * old_width) + abs(
+        np.cos(angle_radian) * old_height
+    )
+
+    image_center = tuple(np.array(image.shape[1::-1]) / 2)
+    rot_mat = cv2.getRotationMatrix2D(image_center, angle, 1.0)
+    rot_mat[1, 2] += (width - old_width) / 2
+    rot_mat[0, 2] += (height - old_height) / 2
+    return cv2.warpAffine(
+        image, rot_mat, (int(round(height)), int(round(width))), borderValue=background
+    )
+
+
+# https://gist.githubusercontent.com/mattjmorrison/932345/raw/b45660bae541610f338bef715642b148c3c4d178/crop_and_resize.py
+def trim(img: np.ndarray, border: Union[int, Tuple[int, int, int]] = 255):
+    # TODO test if removing completely white rows (all pixels are 255) is faster
+    image = Image.fromarray(img)
+    background = Image.new(image.mode, image.size, border)
+    diff = ImageChops.difference(image, background)
+    bbox = diff.getbbox()
+    if bbox:
+        return image.crop(bbox)
+
+
+def determine_rotate_angle(polygon: "np.ndarray") -> float:
+    """
+    Use cv2.minAreaRect to get the angle of the minimal bounding rectangle
+    and convert that angle to rotation angle.
+    The polygon will be rotated by maximum of 45 degrees to either side.
+    :param polygon:
+    :return: rotation angle (-45, 45)
+    """
+    top_left, shape, angle = cv2.minAreaRect(polygon)
+
+    if abs(angle) > RIGHT_ANGLE - 1:
+        # correct rectangle (not rotated) gets angle = RIGHT_ANGLE from minAreaRect
+        # since no way to know whether it should be rotated it will be ignored
+        rotate_angle = 0
+    elif angle > 45:
+        rotate_angle = angle - RIGHT_ANGLE
+    elif angle < -45:
+        rotate_angle = angle + RIGHT_ANGLE
+    elif abs(angle) == 45:
+        # no way to know in which direction it should be rotated
+        rotate_angle = 0
+    else:
+        rotate_angle = angle
+
+    # logger.debug(f"ANGLE: {angle:.2f} => {rotate_angle:.2f}")
+
+    return rotate_angle
diff --git a/kaldi_data_generator/kaldi_data_generator.py b/kaldi_data_generator/kaldi_data_generator.py
index df09faf..2a1a03c 100644
--- a/kaldi_data_generator/kaldi_data_generator.py
+++ b/kaldi_data_generator/kaldi_data_generator.py
@@ -2,28 +2,26 @@
 # -*- coding: utf-8 -*-
 
 import argparse
-import logging
 import os
 import random
 from enum import Enum
-from io import BytesIO
 from pathlib import Path
-from typing import Tuple
 
 import cv2
 import numpy as np
-import requests
 import tqdm
 from apistar.exceptions import ErrorResponse
 from arkindex import ArkindexClient, options_from_env
-from PIL import Image
 
-Box = Tuple[int, int, int, int]
-
-logging.basicConfig(
-    level=logging.INFO, format="%(asctime)s %(levelname)s/%(name)s: %(message)s"
+from kaldi_data_generator.image_utils import (
+    determine_rotate_angle,
+    download_image,
+    extract_min_area_rect_image,
+    extract_polygon_image,
+    rotate,
+    trim,
 )
-logger = logging.getLogger(os.path.basename(__file__))
+from kaldi_data_generator.utils import logger, write_file
 
 api_client = ArkindexClient(**options_from_env())
 
@@ -31,36 +29,16 @@ SEED = 42
 random.seed(SEED)
 MANUAL = "manual"
 TEXT_LINE = "text_line"
-
-
-def download_image(url):
-    """
-    Download an image and open it with Pillow
-    """
-    assert url.startswith("http"), "Image URL must be HTTP(S)"
-    # Download the image
-    # Cannot use stream=True as urllib's responses do not support the seek(int) method,
-    # which is explicitly required by Image.open on file-like objects
-    resp = requests.get(url)
-    resp.raise_for_status()
-
-    # Preprocess the image and prepare it for classification
-    image = Image.open(BytesIO(resp.content))
-    logger.debug(
-        "Downloaded image {} - size={}x{}".format(url, image.size[0], image.size[1])
-    )
-
-    return image
-
-
-def write_file(file_name, content):
-    with open(file_name, "w") as f:
-        f.write(content)
+WHITE = 255
 
 
 class Extraction(Enum):
     boundingRect: int = 0
     polygon: int = 1
+    # minimum containing rectangle with an angle (cv2.min_area_rect)
+    min_area_rect: int = 2
+    deskew_polygon: int = 3
+    deskew_min_area_rect: int = 4
 
 
 class HTRDataGenerator:
@@ -77,6 +55,7 @@ class HTRDataGenerator:
         skip_vertical_lines=False,
         accepted_worker_version_ids=None,
         transcription_type=TEXT_LINE,
+        max_deskew_angle=45,
     ):
 
         self.module = module
@@ -96,6 +75,7 @@ class HTRDataGenerator:
         self.skipped_pages_count = 0
         self.skipped_vertical_lines_count = 0
         self.accepted_lines_count = 0
+        self.max_deskew_angle = max_deskew_angle
 
         if MANUAL in self.accepted_worker_version_ids:
             self.accepted_worker_version_ids[
@@ -211,6 +191,13 @@ class HTRDataGenerator:
             )
             raise e
 
+    def _save_line_image(self, page_id, i, line_img, manifest_fp=None):
+        if self.module == "kraken":
+            cv2.imwrite(f"{self.out_line_dir}/{page_id}_{i}.png", line_img)
+            manifest_fp.write(f"{page_id}_{i}.png\n")
+        else:
+            cv2.imwrite(f"{self.out_line_img_dir}/{page_id}_{i}.jpg", line_img)
+
     def extract_lines(self, page_id: str, image_data: dict):
         if self.should_filter_by_class:
             accepted_zones = self.get_accepted_zones(page_id)
@@ -240,35 +227,71 @@ class HTRDataGenerator:
         sorted_lines = sorted(lines, key=lambda key: (key[0][1], key[0][0]))
 
         if self.module == "kraken":
-            f = open(f"{self.out_line_dir}/manifest.txt", "a")
+            manifest_fp = open(f"{self.out_line_dir}/manifest.txt", "a")
             # append to file, not re-write it
+        else:
+            # not needed for kaldi
+            manifest_fp = None
 
         if self.extraction_mode == Extraction.boundingRect:
             for i, ((x, y, w, h), polygon, text) in enumerate(sorted_lines):
                 cropped = img[y : y + h, x : x + w].copy()
-                if self.module == "kraken":
-                    cv2.imwrite(f"{self.out_line_dir}/{page_id}_{i}.png", cropped)
-                    f.write(f"{page_id}_{i}.png\n")
-                else:
-                    cv2.imwrite(f"{self.out_line_img_dir}/{page_id}_{i}.jpg", cropped)
+                self._save_line_image(page_id, i, cropped, manifest_fp)
 
         elif self.extraction_mode == Extraction.polygon:
             for i, (rect, polygon, text) in enumerate(sorted_lines):
-                polygon_img = self.extract_polygon_image(
+                polygon_img = extract_polygon_image(img, polygon=polygon, rect=rect)
+                self._save_line_image(page_id, i, polygon_img, manifest_fp)
+
+        elif self.extraction_mode == Extraction.min_area_rect:
+            for i, (rect, polygon, text) in enumerate(sorted_lines):
+                min_rect_img = extract_min_area_rect_image(
                     img, polygon=polygon, rect=rect
                 )
-                if self.module == "kraken":
-                    cv2.imwrite(f"{self.out_line_dir}/{page_id}_{i}.png", polygon_img)
-                    f.write(f"{page_id}_{i}.png\n")
-                else:
-                    cv2.imwrite(
-                        f"{self.out_line_img_dir}/{page_id}_{i}.jpg", polygon_img
+
+                self._save_line_image(page_id, i, min_rect_img, manifest_fp)
+
+        elif self.extraction_mode == Extraction.deskew_polygon:
+            for i, (rect, polygon, text) in enumerate(sorted_lines):
+                # get angle from min area rect
+                rotate_angle = determine_rotate_angle(polygon)
+
+                if abs(rotate_angle) > self.max_deskew_angle:
+                    logger.warning(
+                        f"Deskew angle ({rotate_angle}) over the limit ({self.max_deskew_angle}), won't rotate"
                     )
+                    rotate_angle = 0
+
+                # get polygon image
+                polygon_img = extract_polygon_image(img, polygon=polygon, rect=rect)
+
+                trimmed_img = self.rotate_and_trim(polygon_img, rotate_angle)
+
+                self._save_line_image(page_id, i, trimmed_img, manifest_fp)
+
+        elif self.extraction_mode == Extraction.deskew_min_area_rect:
+            for i, (rect, polygon, text) in enumerate(sorted_lines):
+                # get angle from min area rect
+                rotate_angle = determine_rotate_angle(polygon)
+
+                if abs(rotate_angle) > self.max_deskew_angle:
+                    logger.warning(
+                        f"Deskew angle ({rotate_angle}) over the limit ({self.max_deskew_angle}), won't rotate"
+                    )
+                    rotate_angle = 0
+
+                min_rect_img = extract_min_area_rect_image(
+                    img, polygon=polygon, rect=rect
+                )
+
+                trimmed_img = self.rotate_and_trim(min_rect_img, rotate_angle)
+
+                self._save_line_image(page_id, i, trimmed_img, manifest_fp)
         else:
-            raise ValueError("Unsupported extraction mode")
+            raise ValueError(f"Unsupported extraction mode: {self.extraction_mode}")
 
         if self.module == "kraken":
-            f.close()
+            manifest_fp.close()
 
         for i, (rect, polygon, text) in enumerate(sorted_lines):
             if self.module == "kraken":
@@ -276,21 +299,22 @@ class HTRDataGenerator:
             else:
                 write_file(f"{self.out_line_text_dir}/{page_id}_{i}.txt", text)
 
-    @staticmethod
-    def extract_polygon_image(
-        img: "np.ndarray", polygon: "np.ndarray", rect: Box
-    ) -> "np.ndarray":
-        pts = polygon.copy()
-        [x, y, w, h] = rect
-        cropped = img[y : y + h, x : x + w].copy()
-        pts = pts - pts.min(axis=0)
-        mask = np.zeros(cropped.shape[:2], np.uint8)
-        cv2.drawContours(mask, [pts], -1, (255, 255, 255), -1, cv2.LINE_AA)
-        dst = cv2.bitwise_and(cropped, cropped, mask=mask)
-        bg = np.ones_like(cropped, np.uint8) * 255
-        cv2.bitwise_not(bg, bg, mask=mask)
-        dst2 = bg + dst
-        return dst2
+    def rotate_and_trim(self, img, rotate_angle):
+        """
+        Rotate image by given an angle and trim extra whitespace left after rotating
+        """
+        if self.grayscale:
+            background = WHITE
+        else:
+            background = (WHITE, WHITE, WHITE)
+
+        # rotate polygon image
+        deskewed_img = rotate(img, rotate_angle, background)
+        # trim extra whitespace left after rotating
+        trimmed_img = trim(deskewed_img, background)
+        trimmed_img = np.array(trimmed_img)
+
+        return trimmed_img
 
     def run_pages(self, pages: list):
         if all(isinstance(n, str) for n in pages):
@@ -485,6 +509,15 @@ def create_parser():
         help=f"Mode for extracting the line images: {[e.name for e in Extraction]}",
     )
 
+    parser.add_argument(
+        "--max_deskew_angle",
+        type=int,
+        default=45,
+        help="Maximum angle by which deskewing is allowed to rotate the line image. "
+        "If the angle determined by deskew tool is bigger than max "
+        "then that line won't be deskewed/rotated.",
+    )
+
     parser.add_argument(
         "--transcription_type",
         type=str,
@@ -583,6 +616,7 @@ def main():
             skip_vertical_lines=args.skip_vertical_lines,
             transcription_type=args.transcription_type,
             accepted_worker_version_ids=args.accepted_worker_version_ids,
+            max_deskew_angle=args.max_deskew_angle,
         )
 
         # extract all the lines and transcriptions
diff --git a/kaldi_data_generator/utils.py b/kaldi_data_generator/utils.py
new file mode 100644
index 0000000..bc9f9bc
--- /dev/null
+++ b/kaldi_data_generator/utils.py
@@ -0,0 +1,13 @@
+# -*- coding: utf-8 -*-
+import logging
+import os
+
+logging.basicConfig(
+    level=logging.INFO, format="%(asctime)s %(levelname)s/%(name)s: %(message)s"
+)
+logger = logging.getLogger(os.path.basename(__file__))
+
+
+def write_file(file_name, content):
+    with open(file_name, "w") as f:
+        f.write(content)
diff --git a/tests/test_first.py b/tests/test_first.py
deleted file mode 100644
index f5ab6fc..0000000
--- a/tests/test_first.py
+++ /dev/null
@@ -1,7 +0,0 @@
-# -*- coding: utf-8 -*-
-
-from kaldi_data_generator.kaldi_data_generator import MANUAL
-
-
-def test_setup_correct():
-    assert MANUAL
diff --git a/tests/test_image_utils.py b/tests/test_image_utils.py
new file mode 100644
index 0000000..7b91b11
--- /dev/null
+++ b/tests/test_image_utils.py
@@ -0,0 +1,49 @@
+# -*- coding: utf-8 -*-
+
+import cv2
+import numpy as np
+import pytest
+
+from kaldi_data_generator.image_utils import determine_rotate_angle
+
+
+@pytest.mark.parametrize(
+    "angle, expected_rotate_angle",
+    (
+        (-1, -1),
+        (0, 0),
+        (10, 10),
+        (44.9, 45),
+        (45.1, -45),
+        (45, 0),
+        (46, -44),
+        (50, -40),
+        (89, -1),
+        (90, 0),
+        (91, 1),
+        (134, 44),
+        (135, 0),
+        (136, -44),
+        (179, -1),
+        (180, 0),
+        (-180, 0),
+        (-179, 1),
+        (-91, -1),
+        (-90, 0),
+        (-46, 44),
+        (-45, 0),
+        (-44, -44),
+    ),
+)
+def test_determine_rotate_angle(angle, expected_rotate_angle):
+    top_left = [300, 300]
+    shape = [400, 100]
+    # create polygon with expected angle
+    box = cv2.boxPoints((top_left, shape, angle))
+    box = np.int0(box)
+    _, _, calc_angle = cv2.minAreaRect(box)
+    rotate_angle = determine_rotate_angle(box)
+
+    assert (
+        round(rotate_angle) == expected_rotate_angle
+    ), f"C, A, R: {calc_angle} === {angle} === {rotate_angle}"
-- 
GitLab