From d732d9420bc2390e25ca65c9a745f801916e3411 Mon Sep 17 00:00:00 2001 From: Martin <maarand@teklia.com> Date: Thu, 5 Aug 2021 13:52:40 +0200 Subject: [PATCH] don't filter vertical lines that have a rotation class --- kaldi_data_generator/main.py | 27 ++++++++++++++++----------- kaldi_data_generator/utils.py | 9 +++++++++ 2 files changed, 25 insertions(+), 11 deletions(-) diff --git a/kaldi_data_generator/main.py b/kaldi_data_generator/main.py index cd04a86..69f1f63 100644 --- a/kaldi_data_generator/main.py +++ b/kaldi_data_generator/main.py @@ -169,8 +169,6 @@ class HTRDataGenerator: raise e def get_transcriptions(self, page_id: str, accepted_zones): - count = 0 - count_skipped = 0 lines = [] try: for res in self.api_client.paginate( @@ -210,14 +208,8 @@ class HTRDataGenerator: polygon=polygon, text=text, ) - if self.skip_vertical_lines: - rect = trans_data.rect - if rect.height > rect.width: - count_skipped += 1 - continue lines.append(trans_data) - count += 1 if self.should_rotate: classes_by_elem = self.get_children_classes(page_id) @@ -237,7 +229,20 @@ class HTRDataGenerator: else: logger.warning(f"No rotation classes on {trans.element_id}") - return (lines, count, count_skipped) + count_skipped = 0 + if self.skip_vertical_lines: + filtered_lines = [] + for line in lines: + if line.is_vertical: + count_skipped += 1 + continue + filtered_lines.append(line) + + lines = filtered_lines + + count = len(lines) + + return lines, count, count_skipped except ErrorResponse as e: logger.info( @@ -769,9 +774,9 @@ def main(): skipped_ratio = data_generator.skipped_vertical_lines_count / ( data_generator.skipped_vertical_lines_count + data_generator.accepted_lines_count - ) + ) * 100 logger.info( - f"Skipped {data_generator.skipped_vertical_lines_count} vertical lines ({skipped_ratio}/1.0)" + f"Skipped {data_generator.skipped_vertical_lines_count} vertical lines ({round(skipped_ratio, 2)}%)" ) else: logger.info("Creating a split from already downloaded files") diff --git a/kaldi_data_generator/utils.py b/kaldi_data_generator/utils.py index c565ee3..671912c 100644 --- a/kaldi_data_generator/utils.py +++ b/kaldi_data_generator/utils.py @@ -25,6 +25,15 @@ class TranscriptionData: self.rect = BoundingBox._make(cv2.boundingRect(self.polygon)) + @property + def is_vertical(self) -> bool: + """ + Used to filter out vertical lines. Will be ignored when rotation class is given. + """ + if self.rotation_class is None: + return self.rect.height > self.rect.width + return False + def __repr__(self): return str(vars(self)) -- GitLab