From 681f7ff54bf8e6c96ccda1723232aa37dbccf07e Mon Sep 17 00:00:00 2001
From: Martin Maarand <maarand@teklia.com>
Date: Thu, 5 Aug 2021 11:57:05 +0000
Subject: [PATCH] Don't filter vertical lines with rotation class

---
 kaldi_data_generator/main.py  | 33 +++++++++++++++++++--------------
 kaldi_data_generator/utils.py |  9 +++++++++
 2 files changed, 28 insertions(+), 14 deletions(-)

diff --git a/kaldi_data_generator/main.py b/kaldi_data_generator/main.py
index cd04a86..8cedec1 100644
--- a/kaldi_data_generator/main.py
+++ b/kaldi_data_generator/main.py
@@ -169,8 +169,6 @@ class HTRDataGenerator:
             raise e
 
     def get_transcriptions(self, page_id: str, accepted_zones):
-        count = 0
-        count_skipped = 0
         lines = []
         try:
             for res in self.api_client.paginate(
@@ -210,14 +208,8 @@ class HTRDataGenerator:
                     polygon=polygon,
                     text=text,
                 )
-                if self.skip_vertical_lines:
-                    rect = trans_data.rect
-                    if rect.height > rect.width:
-                        count_skipped += 1
-                        continue
 
                 lines.append(trans_data)
-                count += 1
 
             if self.should_rotate:
                 classes_by_elem = self.get_children_classes(page_id)
@@ -237,7 +229,20 @@ class HTRDataGenerator:
                     else:
                         logger.warning(f"No rotation classes on {trans.element_id}")
 
-            return (lines, count, count_skipped)
+            count_skipped = 0
+            if self.skip_vertical_lines:
+                filtered_lines = []
+                for line in lines:
+                    if line.is_vertical:
+                        count_skipped += 1
+                        continue
+                    filtered_lines.append(line)
+
+                lines = filtered_lines
+
+            count = len(lines)
+
+            return lines, count, count_skipped
 
         except ErrorResponse as e:
             logger.info(
@@ -766,12 +771,12 @@ def main():
             logger.info(
                 f"Number of skipped pages: {data_generator.skipped_pages_count}"
             )
-            skipped_ratio = data_generator.skipped_vertical_lines_count / (
-                data_generator.skipped_vertical_lines_count
-                + data_generator.accepted_lines_count
-            )
+            _skipped_vertical_count = data_generator.skipped_vertical_lines_count
+            _total_count = _skipped_vertical_count + data_generator.accepted_lines_count
+            skipped_ratio = _skipped_vertical_count / _total_count * 100
+
             logger.info(
-                f"Skipped {data_generator.skipped_vertical_lines_count} vertical lines ({skipped_ratio}/1.0)"
+                f"Skipped {data_generator.skipped_vertical_lines_count} vertical lines ({round(skipped_ratio, 2)}%)"
             )
     else:
         logger.info("Creating a split from already downloaded files")
diff --git a/kaldi_data_generator/utils.py b/kaldi_data_generator/utils.py
index c565ee3..671912c 100644
--- a/kaldi_data_generator/utils.py
+++ b/kaldi_data_generator/utils.py
@@ -25,6 +25,15 @@ class TranscriptionData:
 
         self.rect = BoundingBox._make(cv2.boundingRect(self.polygon))
 
+    @property
+    def is_vertical(self) -> bool:
+        """
+        Used to filter out vertical lines. Will be ignored when rotation class is given.
+        """
+        if self.rotation_class is None:
+            return self.rect.height > self.rect.width
+        return False
+
     def __repr__(self):
         return str(vars(self))
 
-- 
GitLab