From d732d9420bc2390e25ca65c9a745f801916e3411 Mon Sep 17 00:00:00 2001
From: Martin <maarand@teklia.com>
Date: Thu, 5 Aug 2021 13:52:40 +0200
Subject: [PATCH] don't filter vertical lines that have a rotation class

---
 kaldi_data_generator/main.py  | 27 ++++++++++++++++-----------
 kaldi_data_generator/utils.py |  9 +++++++++
 2 files changed, 25 insertions(+), 11 deletions(-)

diff --git a/kaldi_data_generator/main.py b/kaldi_data_generator/main.py
index cd04a86..69f1f63 100644
--- a/kaldi_data_generator/main.py
+++ b/kaldi_data_generator/main.py
@@ -169,8 +169,6 @@ class HTRDataGenerator:
             raise e
 
     def get_transcriptions(self, page_id: str, accepted_zones):
-        count = 0
-        count_skipped = 0
         lines = []
         try:
             for res in self.api_client.paginate(
@@ -210,14 +208,8 @@ class HTRDataGenerator:
                     polygon=polygon,
                     text=text,
                 )
-                if self.skip_vertical_lines:
-                    rect = trans_data.rect
-                    if rect.height > rect.width:
-                        count_skipped += 1
-                        continue
 
                 lines.append(trans_data)
-                count += 1
 
             if self.should_rotate:
                 classes_by_elem = self.get_children_classes(page_id)
@@ -237,7 +229,20 @@ class HTRDataGenerator:
                     else:
                         logger.warning(f"No rotation classes on {trans.element_id}")
 
-            return (lines, count, count_skipped)
+            count_skipped = 0
+            if self.skip_vertical_lines:
+                filtered_lines = []
+                for line in lines:
+                    if line.is_vertical:
+                        count_skipped += 1
+                        continue
+                    filtered_lines.append(line)
+
+                lines = filtered_lines
+
+            count = len(lines)
+
+            return lines, count, count_skipped
 
         except ErrorResponse as e:
             logger.info(
@@ -769,9 +774,9 @@ def main():
             skipped_ratio = data_generator.skipped_vertical_lines_count / (
                 data_generator.skipped_vertical_lines_count
                 + data_generator.accepted_lines_count
-            )
+            ) * 100
             logger.info(
-                f"Skipped {data_generator.skipped_vertical_lines_count} vertical lines ({skipped_ratio}/1.0)"
+                f"Skipped {data_generator.skipped_vertical_lines_count} vertical lines ({round(skipped_ratio, 2)}%)"
             )
     else:
         logger.info("Creating a split from already downloaded files")
diff --git a/kaldi_data_generator/utils.py b/kaldi_data_generator/utils.py
index c565ee3..671912c 100644
--- a/kaldi_data_generator/utils.py
+++ b/kaldi_data_generator/utils.py
@@ -25,6 +25,15 @@ class TranscriptionData:
 
         self.rect = BoundingBox._make(cv2.boundingRect(self.polygon))
 
+    @property
+    def is_vertical(self) -> bool:
+        """
+        Used to filter out vertical lines. Will be ignored when rotation class is given.
+        """
+        if self.rotation_class is None:
+            return self.rect.height > self.rect.width
+        return False
+
     def __repr__(self):
         return str(vars(self))
 
-- 
GitLab