Skip to content
Snippets Groups Projects
Commit a0c0910c authored by Chaza Abdelwahab, committed by Martin
Browse files

applied filter by worker version commit

parent 1ccedc3b
No related branches found
No related tags found
1 merge request: !6 "filtering by worker version id"
......@@ -65,7 +65,8 @@ class Extraction(Enum):
class HTRDataGenerator:
def __init__(self, module, dataset_name='foo', out_dir_base='/tmp/kaldi_data', grayscale=True,
extraction=Extraction.boundingRect, accepted_slugs=None, accepted_classes=None, filter_printed=False, skip_vertical_lines=False):
extraction=Extraction.boundingRect, accepted_slugs=None, accepted_classes=None, filter_printed=False,
skip_vertical_lines=False, accepted_worker_version_ids=None,):
self.module = module
self.out_dir_base = out_dir_base
self.dataset_name = dataset_name
......@@ -75,11 +76,18 @@ class HTRDataGenerator:
self.should_filter_by_slug = bool(self.accepted_slugs)
self.accepted_classes = accepted_classes
self.should_filter_by_class = bool(self.accepted_classes)
self.accepted_worker_version_ids = accepted_worker_version_ids
self.should_filter_by_worker = bool(self.accepted_worker_version_ids)
self.should_filter_printed = filter_printed
self.skip_vertical_lines = skip_vertical_lines
self.skipped_pages_count = 0
self.skipped_vertical_lines_count = 0
self.accepted_lines_count = 0
if 'None' in self.accepted_worker_version_ids:
self.accepted_worker_version_ids[self.accepted_worker_version_ids.index('None')] = None
if self.module == 'kraken':
self.out_line_dir = out_dir_base
os.makedirs(self.out_line_dir, exist_ok=True)
......@@ -132,7 +140,8 @@ class HTRDataGenerator:
for res in api_client.paginate('ListTranscriptions', id=page_id, recursive=True):
if self.should_filter_by_slug and res['source']['slug'] not in self.accepted_slugs:
continue
if self.should_filter_by_worker and res['worker_version_id'] not in self.accepted_worker_version_ids:
continue
if self.should_filter_by_class and res['element']['zone']['id'] not in accepted_zones:
continue
......@@ -396,7 +405,10 @@ def create_parser():
parser.add_argument('--accepted_classes', nargs='*',
help='List of accepted ml_class names. Filter lines by class of related elements')
parser.add_argument('--accepted_worker_version_ids', nargs='*',
help='List of accepted worker version ids. Filter lines by worker version ids of related elements')
parser.add_argument('--filter_printed', action='store_true',
help='Filter lines annotated as printed')
return parser
......@@ -421,7 +433,8 @@ def main():
accepted_slugs=args.accepted_slugs,
accepted_classes=args.accepted_classes,
filter_printed=args.filter_printed,
skip_vertical_lines=args.skip_vertical_lines)
skip_vertical_lines=args.skip_vertical_lines,
accepted_worker_version_ids=args.accepted_worker_version_ids)
# extract all the lines and transcriptions
if args.pages:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment