From ada7640dfc175d18e544fdfa97a5f3c6b674220b Mon Sep 17 00:00:00 2001 From: Martin <maarand@teklia.com> Date: Fri, 8 Jan 2021 18:28:12 +0100 Subject: [PATCH] use text_line by default --- kaldi_data_generator/kaldi_data_generator.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/kaldi_data_generator/kaldi_data_generator.py b/kaldi_data_generator/kaldi_data_generator.py index 9226343..5156c77 100644 --- a/kaldi_data_generator/kaldi_data_generator.py +++ b/kaldi_data_generator/kaldi_data_generator.py @@ -30,6 +30,7 @@ api_client = ArkindexClient(**options_from_env()) SEED = 42 random.seed(SEED) MANUAL = "manual" +TEXT_LINE = "text_line" def download_image(url): @@ -75,7 +76,9 @@ class HTRDataGenerator: filter_printed=False, skip_vertical_lines=False, accepted_worker_version_ids=None, + transcription_type=TEXT_LINE, ): + self.module = module self.out_dir_base = out_dir_base self.dataset_name = dataset_name @@ -88,7 +91,7 @@ class HTRDataGenerator: self.accepted_worker_version_ids = accepted_worker_version_ids self.should_filter_by_worker = bool(self.accepted_worker_version_ids) self.should_filter_printed = filter_printed - + self.transcription_type = transcription_type self.skip_vertical_lines = skip_vertical_lines self.skipped_pages_count = 0 self.skipped_vertical_lines_count = 0 @@ -178,6 +181,8 @@ class HTRDataGenerator: and res["element"]["zone"]["id"] not in accepted_zones ): continue + if res["element"]["type"] != self.transcription_type: + continue text = res["text"] if not text or not text.strip(): @@ -479,6 +484,13 @@ def create_parser(): help=f"Mode for extracting the line images: {[e.name for e in Extraction]}", ) + parser.add_argument( + "--transcription_type", + type=str, + default="text_line", + help="Which type of elements' transcriptions to use? (page, paragraph, text_line, etc)", + ) + group = parser.add_mutually_exclusive_group(required=False) group.add_argument( "--grayscale", action="store_true", help="Convert images to grayscale" @@ -568,6 +580,7 @@ def main(): accepted_classes=args.accepted_classes, filter_printed=args.filter_printed, skip_vertical_lines=args.skip_vertical_lines, + transcription_type=args.transcription_type, accepted_worker_version_ids=args.accepted_worker_version_ids, ) -- GitLab