Skip to content
Snippets Groups Projects
Commit 8a0edebd authored by Martin Maarand's avatar Martin Maarand
Browse files

Merge branch 'filter_elements' into 'master'

Filter elements according to their classes

See merge request teklia-projects/kaldi_data_generator!1
parents 9556bdbc 96c2464a
No related branches found
No related tags found
1 merge request!1Filter elements according to their classes
...@@ -9,7 +9,7 @@ from enum import Enum ...@@ -9,7 +9,7 @@ from enum import Enum
from io import BytesIO from io import BytesIO
from pathlib import Path from pathlib import Path
from typing import Tuple from typing import Tuple
import time
import cv2 import cv2
import numpy as np import numpy as np
import requests import requests
...@@ -65,14 +65,16 @@ class Extraction(Enum): ...@@ -65,14 +65,16 @@ class Extraction(Enum):
class KaldiDataGenerator: class KaldiDataGenerator:
def __init__(self, dataset_name='foo', out_dir_base='/tmp/kaldi_data', grayscale=True, def __init__(self, dataset_name='foo', out_dir_base='/tmp/kaldi_data', grayscale=True,
extraction=Extraction.boundingRect, accepted_slugs=None): extraction=Extraction.boundingRect, accepted_slugs=None, accepted_classes=None, filter_printed=False):
self.out_dir_base = out_dir_base self.out_dir_base = out_dir_base
self.dataset_name = dataset_name self.dataset_name = dataset_name
self.grayscale = grayscale self.grayscale = grayscale
self.extraction_mode = extraction self.extraction_mode = extraction
self.accepted_slugs = accepted_slugs self.accepted_slugs = accepted_slugs
self.should_filter_by_slug = bool(self.accepted_slugs) self.should_filter_by_slug = bool(self.accepted_slugs)
self.accepted_classes = accepted_classes
self.should_filter_by_class = bool(self.accepted_classes)
self.should_filter_printed = filter_printed
self.out_line_text_dir = os.path.join(self.out_dir_base, 'Transcriptions', self.dataset_name) self.out_line_text_dir = os.path.join(self.out_dir_base, 'Transcriptions', self.dataset_name)
os.makedirs(self.out_line_text_dir, exist_ok=True) os.makedirs(self.out_line_text_dir, exist_ok=True)
self.out_line_img_dir = os.path.join(self.out_dir_base, 'Lines', self.dataset_name) self.out_line_img_dir = os.path.join(self.out_dir_base, 'Lines', self.dataset_name)
...@@ -96,9 +98,29 @@ class KaldiDataGenerator: ...@@ -96,9 +98,29 @@ class KaldiDataGenerator:
count = 0 count = 0
lines = [] lines = []
try: try:
if self.should_filter_by_class:
accepted_zones = []
for elt in api_client.paginate('ListElementChildren',id=page_id, with_best_classes=True):
printed = True
for classification in elt['best_classes']:
if classification['ml_class']['name'] == 'handwritten':
printed = False
for classification in elt['best_classes']:
if classification['ml_class']['name'] in self.accepted_classes:
if self.should_filter_printed:
if not printed:
accepted_zones.append(elt['zone']['id'])
else:
accepted_zones.append(elt['zone']['id'])
logger.info('Number of accepted zone for page {} : {}'.format(page_id,len(accepted_zones)))
for res in api_client.paginate('ListTranscriptions', id=page_id, type='line'): for res in api_client.paginate('ListTranscriptions', id=page_id, type='line'):
if self.should_filter_by_slug and res['source']['slug'] not in self.accepted_slugs: if self.should_filter_by_slug and res['source']['slug'] not in self.accepted_slugs:
continue continue
if self.should_filter_by_class and res['zone']['id'] not in accepted_zones:
continue
text = res['text'] text = res['text']
if not text or not text.strip(): if not text or not text.strip():
continue continue
...@@ -269,7 +291,12 @@ def create_parser(): ...@@ -269,7 +291,12 @@ def create_parser():
parser.add_argument('--accepted_slugs', nargs='*', parser.add_argument('--accepted_slugs', nargs='*',
help='List of accepted slugs for downloading transcriptions') help='List of accepted slugs for downloading transcriptions')
parser.add_argument('--accepted_classes', nargs='*',
help='List of accepted ml_class names. Filter lines by class of related elements')
parser.add_argument('--filter_printed', action='store_true',
help='Filter lines annotated as printed')
return parser return parser
...@@ -282,7 +309,9 @@ def main(): ...@@ -282,7 +309,9 @@ def main():
out_dir_base=args.out_dir, out_dir_base=args.out_dir,
grayscale=args.grayscale, grayscale=args.grayscale,
extraction=args.extraction_mode, extraction=args.extraction_mode,
accepted_slugs=args.accepted_slugs) accepted_slugs=args.accepted_slugs,
accepted_classes=args.accepted_classes,
filter_printed=args.filter_printed)
kaldi_partitioner = KaldiPartitionSplitter(out_dir_base=args.out_dir, kaldi_partitioner = KaldiPartitionSplitter(out_dir_base=args.out_dir,
split_train_ratio=args.train_ratio, split_train_ratio=args.train_ratio,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment