From c5e4f26e610f64177099a1796365128fc577d4ff Mon Sep 17 00:00:00 2001
From: Martin <maarand@teklia.com>
Date: Tue, 15 Feb 2022 17:23:43 +0100
Subject: [PATCH] filter text_lines by style class - handwritten, typewritten or neither (other)

---
 kaldi_data_generator/main.py | 76 +++++++++++++++++++++++++++---------
 1 file changed, 57 insertions(+), 19 deletions(-)

diff --git a/kaldi_data_generator/main.py b/kaldi_data_generator/main.py
index 4123088..cf204cf 100644
--- a/kaldi_data_generator/main.py
+++ b/kaldi_data_generator/main.py
@@ -64,6 +64,19 @@ class Extraction(Enum):
     skew_min_area_rect: int = 6
 
 
+class Style(Enum):
+    """Writing style used to filter text lines (see the --style option)."""
+
+    handwritten: str = "handwritten"
+    typewritten: str = "typewritten"
+    other: str = "other"
+
+
+# Style names looked up among the classes of an element; an element that
+# carries neither of them is treated as Style.other.
+STYLE_CLASSES = [s.name for s in [Style.handwritten, Style.typewritten]]
+
+
 class HTRDataGenerator:
     def __init__(
         self,
@@ -73,7 +86,7 @@ class HTRDataGenerator:
         grayscale=True,
         extraction=Extraction.boundingRect,
         accepted_classes=None,
-        filter_printed=False,
+        style=None,
         skip_vertical_lines=False,
         accepted_worker_version_ids=None,
         transcription_type=TEXT_LINE,
@@ -96,7 +109,8 @@ class HTRDataGenerator:
         self.should_filter_by_class = bool(self.accepted_classes)
         self.accepted_worker_version_ids = accepted_worker_version_ids
         self.should_filter_by_worker = bool(self.accepted_worker_version_ids)
-        self.should_filter_printed = filter_printed
+        self.style = style
+        self.should_filter_by_style = bool(self.style)
         self.transcription_type = transcription_type
         self.skip_vertical_lines = skip_vertical_lines
         self.skipped_pages_count = 0
@@ -174,17 +188,32 @@ class HTRDataGenerator:
             for elt in self.api_client.cached_paginate(
                 "ListElementChildren", id=page_id, with_classes=True
             ):
-                printed = True
-                for classification in elt["classes"]:
-                    if classification["ml_class"]["name"] == "handwritten":
-                        printed = False
-                for classification in elt["classes"]:
-                    if classification["ml_class"]["name"] in self.accepted_classes:
-                        if self.should_filter_printed:
-                            if not printed:
-                                accepted_zones.append(elt["zone"]["id"])
-                        else:
-                            accepted_zones.append(elt["zone"]["id"])
+                # ignore the classifications whose state is "rejected"
+                elem_classes = [c for c in elt["classes"] if c["state"] != "rejected"]
+                class_names = {c["ml_class"]["name"] for c in elem_classes}
+
+                # at first filter to only have elements with accepted classes
+                if self.should_filter_by_class and not any(
+                    name in self.accepted_classes for name in class_names
+                ):
+                    continue
+
+                if self.should_filter_by_style:
+                    found_styles = class_names.intersection(STYLE_CLASSES)
+                    if len(found_styles) > 1:
+                        raise ValueError(
+                            f"Multiple style classes on the same element! {elt['id']} - {elem_classes}"
+                        )
+                    if found_styles:
+                        found_style = Style(found_styles.pop())
+                    else:
+                        # no handwritten or typewritten found, so other
+                        found_style = Style.other
+                    if found_style != self.style:
+                        continue
+
+                accepted_zones.append(elt["zone"]["id"])
+
             logger.info(
                 "Number of accepted zone for page {} : {}".format(
                     page_id, len(accepted_zones)
@@ -256,7 +285,7 @@ class HTRDataGenerator:
                 ):
                     continue
                 if (
-                    self.should_filter_by_class
+                    (self.should_filter_by_class or self.should_filter_by_style)
                     and res["element"]["zone"]["id"] not in accepted_zones
                 ):
                     continue
@@ -362,7 +391,7 @@ class HTRDataGenerator:
             cv2.imwrite(f"{self.out_line_img_dir}/{page_id}_{i}.jpg", line_img)
 
     def extract_lines(self, page_id: str, image_data: dict):
-        if self.should_filter_by_class:
+        if self.should_filter_by_class or self.should_filter_by_style:
             accepted_zones = self.get_accepted_zones(page_id)
         else:
             accepted_zones = []
@@ -815,9 +844,13 @@ def create_parser():
     )
 
     parser.add_argument(
-        "--filter_printed",
-        action="store_true",
-        help="Filter lines annotated as printed",
+        "--style",
+        # Style(value) raises ValueError for an unknown style, which argparse
+        # turns into a usage error (Style[name] would raise an uncaught KeyError).
+        type=lambda x: Style(x.lower()),
+        default=None,
+        help=f"Filter line images by style class. 'other' corresponds to line elements that "
+        f"have neither handwritten or typewritten class : {[s.name for s in Style]}",
     )
 
     parser.add_argument(
@@ -857,6 +890,11 @@ def main():
     if not args.dataset_name and not args.split_only and not args.format == "kraken":
         parser.error("--dataset_name must be specified (unless --split-only)")
 
+    if args.style and args.accepted_classes:
+        if any(c in args.accepted_classes for c in STYLE_CLASSES):
+            parser.error(f"--style class values ({STYLE_CLASSES}) shouldn't be in the accepted_classes list "
+                         "if both --style and --accepted_classes are used together.")
+
     logger.info(f"ARGS {args} \n")
 
     api_client = create_api_client(args.cache_dir)
@@ -869,7 +907,7 @@ def main():
             grayscale=args.grayscale,
             extraction=args.extraction_mode,
             accepted_classes=args.accepted_classes,
-            filter_printed=args.filter_printed,
+            style=args.style,
             skip_vertical_lines=args.skip_vertical_lines,
             transcription_type=args.transcription_type,
             accepted_worker_version_ids=args.accepted_worker_version_ids,
-- 
GitLab