Skip to content
Snippets Groups Projects
Commit 8a0edebd authored by Martin Maarand
Browse files

Merge branch 'filter_elements' into 'master'

Filter elements according to their classes

See merge request teklia-projects/kaldi_data_generator!1
parents 9556bdbc 96c2464a
No related branches found
No related tags found
1 merge request: !1 Filter elements according to their classes
......@@ -9,7 +9,7 @@ from enum import Enum
from io import BytesIO
from pathlib import Path
from typing import Tuple
import time
import cv2
import numpy as np
import requests
......@@ -65,14 +65,16 @@ class Extraction(Enum):
class KaldiDataGenerator:
def __init__(self, dataset_name='foo', out_dir_base='/tmp/kaldi_data', grayscale=True,
extraction=Extraction.boundingRect, accepted_slugs=None):
extraction=Extraction.boundingRect, accepted_slugs=None, accepted_classes=None, filter_printed=False):
self.out_dir_base = out_dir_base
self.dataset_name = dataset_name
self.grayscale = grayscale
self.extraction_mode = extraction
self.accepted_slugs = accepted_slugs
self.should_filter_by_slug = bool(self.accepted_slugs)
self.accepted_classes = accepted_classes
self.should_filter_by_class = bool(self.accepted_classes)
self.should_filter_printed = filter_printed
self.out_line_text_dir = os.path.join(self.out_dir_base, 'Transcriptions', self.dataset_name)
os.makedirs(self.out_line_text_dir, exist_ok=True)
self.out_line_img_dir = os.path.join(self.out_dir_base, 'Lines', self.dataset_name)
......@@ -96,9 +98,29 @@ class KaldiDataGenerator:
count = 0
lines = []
try:
if self.should_filter_by_class:
accepted_zones = []
for elt in api_client.paginate('ListElementChildren',id=page_id, with_best_classes=True):
printed = True
for classification in elt['best_classes']:
if classification['ml_class']['name'] == 'handwritten':
printed = False
for classification in elt['best_classes']:
if classification['ml_class']['name'] in self.accepted_classes:
if self.should_filter_printed:
if not printed:
accepted_zones.append(elt['zone']['id'])
else:
accepted_zones.append(elt['zone']['id'])
logger.info('Number of accepted zone for page {} : {}'.format(page_id,len(accepted_zones)))
for res in api_client.paginate('ListTranscriptions', id=page_id, type='line'):
if self.should_filter_by_slug and res['source']['slug'] not in self.accepted_slugs:
continue
if self.should_filter_by_class and res['zone']['id'] not in accepted_zones:
continue
text = res['text']
if not text or not text.strip():
continue
......@@ -269,7 +291,12 @@ def create_parser():
parser.add_argument('--accepted_slugs', nargs='*',
help='List of accepted slugs for downloading transcriptions')
parser.add_argument('--accepted_classes', nargs='*',
help='List of accepted ml_class names. Filter lines by class of related elements')
parser.add_argument('--filter_printed', action='store_true',
help='Filter lines annotated as printed')
return parser
......@@ -282,7 +309,9 @@ def main():
out_dir_base=args.out_dir,
grayscale=args.grayscale,
extraction=args.extraction_mode,
accepted_slugs=args.accepted_slugs)
accepted_slugs=args.accepted_slugs,
accepted_classes=args.accepted_classes,
filter_printed=args.filter_printed)
kaldi_partitioner = KaldiPartitionSplitter(out_dir_base=args.out_dir,
split_train_ratio=args.train_ratio,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment