diff --git a/kaldi_data_generator.py b/kaldi_data_generator.py
index c8a2cc0add99135971d940d747d9a8321cf84281..98f7420ddca81ac72f7e26d5e69f62a5e3ff33be 100644
--- a/kaldi_data_generator.py
+++ b/kaldi_data_generator.py
@@ -9,7 +9,7 @@ from enum import Enum
 from io import BytesIO
 from pathlib import Path
 from typing import Tuple
-import time
+
 import cv2
 import numpy as np
 import requests
@@ -99,20 +99,20 @@ class KaldiDataGenerator:
         lines = []
         try:
             if self.should_filter_by_class:
-                    accepted_zones = []
-                    for elt in api_client.paginate('ListElementChildren',id=page_id, with_best_classes=True):
-                        printed = True
-                        for classification in elt['best_classes']:
-                            if classification['ml_class']['name'] == 'handwritten':
-                                printed = False
-                        for classification in elt['best_classes']:
-                            if classification['ml_class']['name'] in self.accepted_classes:
-                                if self.should_filter_printed:
-                                    if not printed:
-                                        accepted_zones.append(elt['zone']['id'])
-                                else:
+                accepted_zones = []
+                for elt in api_client.paginate('ListElementChildren', id=page_id, with_best_classes=True):
+                    printed = True
+                    for classification in elt['best_classes']:
+                        if classification['ml_class']['name'] == 'handwritten':
+                            printed = False
+                    for classification in elt['best_classes']:
+                        if classification['ml_class']['name'] in self.accepted_classes:
+                            if self.should_filter_printed:
+                                if not printed:
                                     accepted_zones.append(elt['zone']['id'])
-                    logger.info('Number of accepted zone for page {} : {}'.format(page_id,len(accepted_zones)))
+                                else:
+                                    accepted_zones.append(elt['zone']['id'])
+                logger.info('Number of accepted zone for page {} : {}'.format(page_id, len(accepted_zones)))
 
             for res in api_client.paginate('ListTranscriptions', id=page_id, type='line', recursive=True):
                 if self.should_filter_by_slug and res['source']['slug'] not in self.accepted_slugs:
@@ -120,7 +120,7 @@ class KaldiDataGenerator:
 
                 if self.should_filter_by_class and res['zone']['id'] not in accepted_zones:
                     continue
-                
+
                 text = res['text']
                 if not text or not text.strip():
                     continue
@@ -224,7 +224,8 @@ class Split(Enum):
 
 
 class KaldiPartitionSplitter:
-    def __init__(self, out_dir_base='/tmp/kaldi_data', split_train_ratio=0.8, split_test_ratio=0.1, use_existing_split=False):
+    def __init__(self, out_dir_base='/tmp/kaldi_data', split_train_ratio=0.8, split_test_ratio=0.1,
+                 use_existing_split=False):
         self.out_dir_base = out_dir_base
         self.split_train_ratio = split_train_ratio
         self.split_test_ratio = split_test_ratio
@@ -331,7 +332,7 @@ def create_parser():
                         help='List of accepted ml_class names. '
                              'Filter lines by class of related elements')
     parser.add_argument('--filter_printed', action='store_true',
-                         help='Filter lines annotated as printed')
+                        help='Filter lines annotated as printed')
 
     return parser
 
@@ -341,13 +342,14 @@ def main():
     logger.info(f"ARGS {args} \n")
 
     if not args.split_only:
-        kaldi_data_generator = KaldiDataGenerator(dataset_name=args.dataset_name,
-                                                  out_dir_base=args.out_dir,
-                                                  grayscale=args.grayscale,
-                                                  extraction=args.extraction_mode,
-                                                  accepted_slugs=args.accepted_slugs,
-                                                  accepted_classes=args.accepted_classes,
-                                                  filter_printed=args.filter_printed)
+        kaldi_data_generator = KaldiDataGenerator(
+            dataset_name=args.dataset_name,
+            out_dir_base=args.out_dir,
+            grayscale=args.grayscale,
+            extraction=args.extraction_mode,
+            accepted_slugs=args.accepted_slugs,
+            accepted_classes=args.accepted_classes,
+            filter_printed=args.filter_printed)
 
         # extract all the lines and transcriptions
         # if args.pages:
@@ -361,10 +363,11 @@ def main():
     else:
         logger.info("Creating a split from already downloaded files")
 
-    kaldi_partitioner = KaldiPartitionSplitter(out_dir_base=args.out_dir,
-                                               split_train_ratio=args.train_ratio,
-                                               split_test_ratio=args.test_ratio,
-                                               use_existing_split=args.use_existing_split)
+    kaldi_partitioner = KaldiPartitionSplitter(
+        out_dir_base=args.out_dir,
+        split_train_ratio=args.train_ratio,
+        split_test_ratio=args.test_ratio,
+        use_existing_split=args.use_existing_split)
 
     # create partitions from all the extracted data
     kaldi_partitioner.create_partitions()
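
Reading aid for the re-indented hunk at lines 99-118: the interleaved -/+ lines make the class-filtering logic hard to follow, but its behaviour is unchanged. The accepted-zone collection is equivalent to the standalone sketch below; the function name get_accepted_zones and its bare parameters are illustrative only (the real code lives inside KaldiDataGenerator), everything else mirrors the patched code.

def get_accepted_zones(api_client, page_id, accepted_classes, filter_printed):
    # Collect zone ids of page children whose best class is in accepted_classes,
    # optionally dropping elements that are not classified as handwritten.
    accepted_zones = []
    for elt in api_client.paginate('ListElementChildren', id=page_id, with_best_classes=True):
        # An element counts as printed unless one of its best classes is 'handwritten'.
        printed = not any(
            c['ml_class']['name'] == 'handwritten' for c in elt['best_classes']
        )
        for classification in elt['best_classes']:
            if classification['ml_class']['name'] in accepted_classes:
                if filter_printed and printed:
                    continue  # skip printed elements when filtering is enabled
                accepted_zones.append(elt['zone']['id'])
    return accepted_zones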