diff --git a/kaldi_data_generator/main.py b/kaldi_data_generator/main.py
index cd04a86817d1af9d7b5a2d8983efa63fe8af0cbe..8cedec131e3dfd045cb6edf623780a50aafd5d7f 100644
--- a/kaldi_data_generator/main.py
+++ b/kaldi_data_generator/main.py
@@ -169,8 +169,6 @@ class HTRDataGenerator:
             raise e
 
     def get_transcriptions(self, page_id: str, accepted_zones):
-        count = 0
-        count_skipped = 0
         lines = []
         try:
             for res in self.api_client.paginate(
@@ -210,14 +208,8 @@ class HTRDataGenerator:
                     polygon=polygon,
                     text=text,
                 )
-                if self.skip_vertical_lines:
-                    rect = trans_data.rect
-                    if rect.height > rect.width:
-                        count_skipped += 1
-                        continue
 
                 lines.append(trans_data)
-                count += 1
 
             if self.should_rotate:
                 classes_by_elem = self.get_children_classes(page_id)
@@ -237,7 +229,17 @@ class HTRDataGenerator:
                     else:
                         logger.warning(f"No rotation classes on {trans.element_id}")
 
-            return (lines, count, count_skipped)
+            count_skipped = 0
+            if self.skip_vertical_lines:
+                # Filter only after the should_rotate block above: is_vertical
+                # is always False for a line that has a rotation class.
+                kept_lines = [line for line in lines if not line.is_vertical]
+                count_skipped = len(lines) - len(kept_lines)
+                lines = kept_lines
+
+            count = len(lines)
+
+            return lines, count, count_skipped
 
         except ErrorResponse as e:
             logger.info(
@@ -766,12 +768,16 @@ def main():
             logger.info(
                 f"Number of skipped pages: {data_generator.skipped_pages_count}"
             )
-            skipped_ratio = data_generator.skipped_vertical_lines_count / (
-                data_generator.skipped_vertical_lines_count
-                + data_generator.accepted_lines_count
-            )
+            _skipped_vertical_count = data_generator.skipped_vertical_lines_count
+            _total_count = _skipped_vertical_count + data_generator.accepted_lines_count
+            # Nothing accepted and nothing skipped: report 0% rather than
+            # raising ZeroDivisionError.
+            skipped_ratio = (
+                _skipped_vertical_count / _total_count * 100 if _total_count else 0.0
+            )
+
             logger.info(
-                f"Skipped {data_generator.skipped_vertical_lines_count} vertical lines ({skipped_ratio}/1.0)"
+                f"Skipped {_skipped_vertical_count} vertical lines ({round(skipped_ratio, 2)}%)"
             )
         else:
             logger.info("Creating a split from already downloaded files")
diff --git a/kaldi_data_generator/utils.py b/kaldi_data_generator/utils.py
index c565ee3abea55ad389f1f99b4142733794fbe132..671912c92463892d0376b5bfd072f6f0bec945e7 100644
--- a/kaldi_data_generator/utils.py
+++ b/kaldi_data_generator/utils.py
@@ -25,5 +25,17 @@ class TranscriptionData:
 
         self.rect = BoundingBox._make(cv2.boundingRect(self.polygon))
 
+    @property
+    def is_vertical(self) -> bool:
+        """
+        Whether this line counts as vertical: no rotation class and a bounding box taller than it is wide.
+
+        Used to filter out vertical lines. A line that has a rotation class is
+        never reported as vertical, whatever the shape of its bounding box.
+        """
+        if self.rotation_class is None:
+            return self.rect.height > self.rect.width
+        return False
+
     def __repr__(self):
         return str(vars(self))