diff --git a/kaldi_data_generator/kaldi_data_generator.py b/kaldi_data_generator/kaldi_data_generator.py index 922634374b9723c443e552bb9f3205cc20957b77..5156c7735c3adf1af8714b0d75bba6a490cf6ba4 100644 --- a/kaldi_data_generator/kaldi_data_generator.py +++ b/kaldi_data_generator/kaldi_data_generator.py @@ -30,6 +30,7 @@ api_client = ArkindexClient(**options_from_env()) SEED = 42 random.seed(SEED) MANUAL = "manual" +TEXT_LINE = "text_line" def download_image(url): @@ -75,7 +76,9 @@ class HTRDataGenerator: filter_printed=False, skip_vertical_lines=False, accepted_worker_version_ids=None, + transcription_type=TEXT_LINE, ): + self.module = module self.out_dir_base = out_dir_base self.dataset_name = dataset_name @@ -88,7 +91,7 @@ class HTRDataGenerator: self.accepted_worker_version_ids = accepted_worker_version_ids self.should_filter_by_worker = bool(self.accepted_worker_version_ids) self.should_filter_printed = filter_printed - + self.transcription_type = transcription_type self.skip_vertical_lines = skip_vertical_lines self.skipped_pages_count = 0 self.skipped_vertical_lines_count = 0 @@ -178,6 +181,8 @@ class HTRDataGenerator: and res["element"]["zone"]["id"] not in accepted_zones ): continue + if res["element"]["type"] != self.transcription_type: + continue text = res["text"] if not text or not text.strip(): @@ -479,6 +484,13 @@ def create_parser(): help=f"Mode for extracting the line images: {[e.name for e in Extraction]}", ) + parser.add_argument( + "--transcription_type", + type=str, + default="text_line", + help="Which type of elements' transcriptions to use? (page, paragraph, text_line, etc)", + ) + group = parser.add_mutually_exclusive_group(required=False) group.add_argument( "--grayscale", action="store_true", help="Convert images to grayscale" @@ -568,6 +580,7 @@ def main(): accepted_classes=args.accepted_classes, filter_printed=args.filter_printed, skip_vertical_lines=args.skip_vertical_lines, + transcription_type=args.transcription_type, accepted_worker_version_ids=args.accepted_worker_version_ids, )