diff --git a/kaldi_data_generator/main.py b/kaldi_data_generator/main.py index 45517b3f2ca2abb6790494e5edd581c1278b59bc..0b1ff4e3a1fd146c27e173232ca1ef2ea3192f50 100644 --- a/kaldi_data_generator/main.py +++ b/kaldi_data_generator/main.py @@ -82,6 +82,7 @@ class HTRDataGenerator: grayscale=True, extraction=Extraction.boundingRect, accepted_classes=None, + ignored_classes=None, style=None, skip_vertical_lines=False, accepted_worker_version_ids=None, @@ -102,7 +103,8 @@ class HTRDataGenerator: self.grayscale = grayscale self.extraction_mode = extraction self.accepted_classes = accepted_classes - self.should_filter_by_class = bool(self.accepted_classes) + self.ignored_classes = ignored_classes + self.should_filter_by_class = bool(self.accepted_classes) or bool(self.ignored_classes) self.accepted_worker_version_ids = accepted_worker_version_ids self.should_filter_by_worker = bool(self.accepted_worker_version_ids) self.style = style @@ -190,12 +192,17 @@ class HTRDataGenerator: should_accept = True if self.should_filter_by_class: # at first filter to only have elements with accepted classes - should_accept = False + # if accepted classes list is empty then should accept all + # except for ignored classes + should_accept = len(self.accepted_classes) == 0 for classification in elem_classes: class_name = classification["ml_class"]["name"] if class_name in self.accepted_classes: should_accept = True break + elif class_name in self.ignored_classes: + should_accept = False + break if not should_accept: continue @@ -838,10 +845,19 @@ def create_parser(): help="skips vertical lines when downloading", ) + parser.add_argument( + "--ignored_classes", + nargs="*", + default=[], + help="List of ignored ml_class names. Filter lines by class", + ) + + parser.add_argument( "--accepted_classes", nargs="*", - help="List of accepted ml_class names. Filter lines by class of related elements", + default=[], + help="List of accepted ml_class names. Filter lines by class", ) parser.add_argument( @@ -901,11 +917,18 @@ def main(): if not args.dataset_name and not args.split_only and not args.format == "kraken": parser.error("--dataset_name must be specified (unless --split-only)") - if args.style and args.accepted_classes: - if any(c in args.accepted_classes for c in STYLE_CLASSES): + if args.accepted_classes and args.ignored_classes: + if set(args.accepted_classes) & set(args.ignored_classes): + parser.error( + f"--accepted_classes and --ignored_classes values must not overlap ({args.accepted_classes} - {args.ignored_classes})" + ) + + if args.style and (args.accepted_classes or args.ignored_classes): + if set(STYLE_CLASSES) & (set(args.accepted_classes) | set(args.ignored_classes)): parser.error( f"--style class values ({STYLE_CLASSES}) shouldn't be in the accepted_classes list " - "if both --style and --accepted_classes are used together." + f"(or ignored_classes list) " + "if both --style and --accepted_classes (or --ignored_classes) are used together." ) logger.info(f"ARGS {args} \n") @@ -920,6 +943,7 @@ def main(): grayscale=args.grayscale, extraction=args.extraction_mode, accepted_classes=args.accepted_classes, + ignored_classes=args.ignored_classes, style=args.style, skip_vertical_lines=args.skip_vertical_lines, transcription_type=args.transcription_type,