Skip to content
Snippets Groups Projects
Commit 96c2464a authored by Raphael Toumi's avatar Raphael Toumi Committed by Martin Maarand
Browse files

Filter elements according to their classes

parent 9556bdbc
No related branches found
No related tags found
No related merge requests found
......@@ -9,7 +9,7 @@ from enum import Enum
from io import BytesIO
from pathlib import Path
from typing import Tuple
import time
import cv2
import numpy as np
import requests
......@@ -65,14 +65,16 @@ class Extraction(Enum):
class KaldiDataGenerator:
def __init__(self, dataset_name='foo', out_dir_base='/tmp/kaldi_data', grayscale=True,
extraction=Extraction.boundingRect, accepted_slugs=None):
extraction=Extraction.boundingRect, accepted_slugs=None, accepted_classes=None, filter_printed=False):
self.out_dir_base = out_dir_base
self.dataset_name = dataset_name
self.grayscale = grayscale
self.extraction_mode = extraction
self.accepted_slugs = accepted_slugs
self.should_filter_by_slug = bool(self.accepted_slugs)
self.accepted_classes = accepted_classes
self.should_filter_by_class = bool(self.accepted_classes)
self.should_filter_printed = filter_printed
self.out_line_text_dir = os.path.join(self.out_dir_base, 'Transcriptions', self.dataset_name)
os.makedirs(self.out_line_text_dir, exist_ok=True)
self.out_line_img_dir = os.path.join(self.out_dir_base, 'Lines', self.dataset_name)
......@@ -96,9 +98,29 @@ class KaldiDataGenerator:
count = 0
lines = []
try:
if self.should_filter_by_class:
accepted_zones = []
for elt in api_client.paginate('ListElementChildren',id=page_id, with_best_classes=True):
printed = True
for classification in elt['best_classes']:
if classification['ml_class']['name'] == 'handwritten':
printed = False
for classification in elt['best_classes']:
if classification['ml_class']['name'] in self.accepted_classes:
if self.should_filter_printed:
if not printed:
accepted_zones.append(elt['zone']['id'])
else:
accepted_zones.append(elt['zone']['id'])
logger.info('Number of accepted zone for page {} : {}'.format(page_id,len(accepted_zones)))
for res in api_client.paginate('ListTranscriptions', id=page_id, type='line'):
if self.should_filter_by_slug and res['source']['slug'] not in self.accepted_slugs:
continue
if self.should_filter_by_class and res['zone']['id'] not in accepted_zones:
continue
text = res['text']
if not text or not text.strip():
continue
......@@ -269,7 +291,12 @@ def create_parser():
parser.add_argument('--accepted_slugs', nargs='*',
help='List of accepted slugs for downloading transcriptions')
parser.add_argument('--accepted_classes', nargs='*',
help='List of accepted ml_class names. Filter lines by class of related elements')
parser.add_argument('--filter_printed', action='store_true',
help='Filter lines annotated as printed')
return parser
......@@ -282,7 +309,9 @@ def main():
out_dir_base=args.out_dir,
grayscale=args.grayscale,
extraction=args.extraction_mode,
accepted_slugs=args.accepted_slugs)
accepted_slugs=args.accepted_slugs,
accepted_classes=args.accepted_classes,
filter_printed=args.filter_printed)
kaldi_partitioner = KaldiPartitionSplitter(out_dir_base=args.out_dir,
split_train_ratio=args.train_ratio,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment