diff --git a/kaldi_data_generator/main.py b/kaldi_data_generator/main.py index 41230882a414ae74a915850634f846c98d648164..a9706ba3b2cf6a7f1ea44d5d4fceffee93e2b254 100644 --- a/kaldi_data_generator/main.py +++ b/kaldi_data_generator/main.py @@ -64,6 +64,15 @@ class Extraction(Enum): skew_min_area_rect: int = 6 +class Style(Enum): + handwritten: str = "handwritten" + typewritten: str = "typewritten" + other: str = "other" + + +STYLE_CLASSES = [s.name for s in [Style.handwritten, Style.typewritten]] + + class HTRDataGenerator: def __init__( self, @@ -73,7 +82,8 @@ class HTRDataGenerator: grayscale=True, extraction=Extraction.boundingRect, accepted_classes=None, - filter_printed=False, + ignored_classes=None, + style=None, skip_vertical_lines=False, accepted_worker_version_ids=None, transcription_type=TEXT_LINE, @@ -93,10 +103,14 @@ class HTRDataGenerator: self.grayscale = grayscale self.extraction_mode = extraction self.accepted_classes = accepted_classes - self.should_filter_by_class = bool(self.accepted_classes) + self.ignored_classes = ignored_classes + self.should_filter_by_class = bool(self.accepted_classes) or bool( + self.ignored_classes + ) self.accepted_worker_version_ids = accepted_worker_version_ids self.should_filter_by_worker = bool(self.accepted_worker_version_ids) - self.should_filter_printed = filter_printed + self.style = style + self.should_filter_by_style = bool(self.style) self.transcription_type = transcription_type self.skip_vertical_lines = skip_vertical_lines self.skipped_pages_count = 0 @@ -174,17 +188,49 @@ class HTRDataGenerator: for elt in self.api_client.cached_paginate( "ListElementChildren", id=page_id, with_classes=True ): - printed = True - for classification in elt["classes"]: - if classification["ml_class"]["name"] == "handwritten": - printed = False - for classification in elt["classes"]: - if classification["ml_class"]["name"] in self.accepted_classes: - if self.should_filter_printed: - if not printed: - accepted_zones.append(elt["zone"]["id"]) - else: - accepted_zones.append(elt["zone"]["id"]) + elem_classes = [c for c in elt["classes"] if c["state"] != "rejected"] + + should_accept = True + if self.should_filter_by_class: + # at first filter to only have elements with accepted classes + # if accepted classes list is empty then should accept all + # except for ignored classes + should_accept = len(self.accepted_classes) == 0 + for classification in elem_classes: + class_name = classification["ml_class"]["name"] + if class_name in self.accepted_classes: + should_accept = True + break + elif class_name in self.ignored_classes: + should_accept = False + break + + if not should_accept: + continue + + if self.should_filter_by_style: + style_counts = Counter() + for classification in elem_classes: + class_name = classification["ml_class"]["name"] + if class_name in STYLE_CLASSES: + style_counts[class_name] += 1 + + if len(style_counts) == 0: + # no handwritten or typewritten found, so other + found_class = Style.other + elif len(style_counts) == 1: + found_class = list(style_counts.keys())[0] + found_class = Style(found_class) + else: + raise ValueError( + f"Multiple style classes on the same element! {elt['id']} - {elem_classes}" + ) + + if found_class == self.style: + accepted_zones.append(elt["zone"]["id"]) + else: + accepted_zones.append(elt["zone"]["id"]) + logger.info( "Number of accepted zone for page {} : {}".format( page_id, len(accepted_zones) @@ -255,9 +301,8 @@ class HTRDataGenerator: and res["worker_version_id"] not in self.accepted_worker_version_ids ): continue - if ( - self.should_filter_by_class - and res["element"]["zone"]["id"] not in accepted_zones + if (self.should_filter_by_class or self.should_filter_by_style) and ( + res["element"]["zone"]["id"] not in accepted_zones ): continue if res["element"]["type"] != self.transcription_type: @@ -362,7 +407,7 @@ class HTRDataGenerator: cv2.imwrite(f"{self.out_line_img_dir}/{page_id}_{i}.jpg", line_img) def extract_lines(self, page_id: str, image_data: dict): - if self.should_filter_by_class: + if self.should_filter_by_class or self.should_filter_by_style: accepted_zones = self.get_accepted_zones(page_id) else: accepted_zones = [] @@ -796,10 +841,18 @@ def create_parser(): help="skips vertical lines when downloading", ) + parser.add_argument( + "--ignored_classes", + nargs="*", + default=[], + help="List of ignored ml_class names. Filter lines by class", + ) + parser.add_argument( "--accepted_classes", nargs="*", - help="List of accepted ml_class names. Filter lines by class of related elements", + default=[], + help="List of accepted ml_class names. Filter lines by class", ) parser.add_argument( @@ -815,9 +868,11 @@ def create_parser(): ) parser.add_argument( - "--filter_printed", - action="store_true", - help="Filter lines annotated as printed", + "--style", + type=lambda x: Style[x.lower()], + default=None, + help=f"Filter line images by style class. 'other' corresponds to line elements that " + f"have neither handwritten or typewritten class : {[s.name for s in Style]}", ) parser.add_argument( @@ -857,6 +912,22 @@ def main(): if not args.dataset_name and not args.split_only and not args.format == "kraken": parser.error("--dataset_name must be specified (unless --split-only)") + if args.accepted_classes and args.ignored_classes: + if set(args.accepted_classes) & set(args.ignored_classes): + parser.error( + f"--accepted_classes and --ignored_classes values must not overlap ({args.accepted_classes} - {args.ignored_classes})" + ) + + if args.style and (args.accepted_classes or args.ignored_classes): + if set(STYLE_CLASSES) & ( + set(args.accepted_classes) | set(args.ignored_classes) + ): + parser.error( + f"--style class values ({STYLE_CLASSES}) shouldn't be in the accepted_classes list " + f"(or ignored_classes list) " + "if both --style and --accepted_classes (or --ignored_classes) are used together." + ) + logger.info(f"ARGS {args} \n") api_client = create_api_client(args.cache_dir) @@ -869,7 +940,8 @@ def main(): grayscale=args.grayscale, extraction=args.extraction_mode, accepted_classes=args.accepted_classes, - filter_printed=args.filter_printed, + ignored_classes=args.ignored_classes, + style=args.style, skip_vertical_lines=args.skip_vertical_lines, transcription_type=args.transcription_type, accepted_worker_version_ids=args.accepted_worker_version_ids,