Commit 1ccedc3b authored by Martin Maarand

Merge branch 'skip_vertical_lines' into 'master'

Add option to skip vertical lines

See merge request teklia/kaldi_data_generator!4
parents 90175f84 45c9e42e
@@ -20,6 +20,8 @@ Use help to list possible parameters:
 ```bash
 python kaldi_data_generator.py --help
 ```
+There is also a `--skip_vertical_lines` option that skips all vertical transcriptions.
 #### Kaldi format
 Simple example:
 ```bash
...
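The new flag documented above combines with the existing download options. A minimal invocation sketch (the dataset name and volume id are placeholders, not from the commit):

```bash
# Hypothetical run: download one volume, skipping vertical lines
python kaldi_data_generator.py \
    --dataset_name my_dataset \
    --volumes aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee \
    --skip_vertical_lines
```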
@@ -65,7 +65,7 @@ class Extraction(Enum):
 class HTRDataGenerator:
     def __init__(self, module, dataset_name='foo', out_dir_base='/tmp/kaldi_data', grayscale=True,
-                 extraction=Extraction.boundingRect, accepted_slugs=None, accepted_classes=None, filter_printed=False):
+                 extraction=Extraction.boundingRect, accepted_slugs=None, accepted_classes=None, filter_printed=False, skip_vertical_lines=False):
         self.module = module
         self.out_dir_base = out_dir_base
         self.dataset_name = dataset_name
@@ -76,6 +76,10 @@ class HTRDataGenerator:
         self.accepted_classes = accepted_classes
         self.should_filter_by_class = bool(self.accepted_classes)
         self.should_filter_printed = filter_printed
+        self.skip_vertical_lines = skip_vertical_lines
+        self.skipped_pages_count = 0
+        self.skipped_vertical_lines_count = 0
+        self.accepted_lines_count = 0
         if self.module == 'kraken':
             self.out_line_dir = out_dir_base
             os.makedirs(self.out_line_dir, exist_ok=True)
@@ -98,7 +102,7 @@ class HTRDataGenerator:
                          out_full_img_path, format='jpeg')
         img = cv2.imread(out_full_img_path)
         return img

     def get_accepted_zones(self, page_id: str):
         try:
             accepted_zones = []
@@ -122,6 +126,7 @@ class HTRDataGenerator:
     def get_transcriptions(self, page_id: str, accepted_zones):
         count = 0
+        count_skipped = 0
         lines = []
         try:
             for res in api_client.paginate('ListTranscriptions', id=page_id, recursive=True):
@@ -144,9 +149,14 @@ class HTRDataGenerator:
                 polygon = np.asarray(polygon).clip(0)
                 [x, y, w, h] = cv2.boundingRect(polygon)
+                if self.skip_vertical_lines:
+                    if h > w:
+                        count_skipped += 1
+                        continue
                 lines.append(((x, y, w, h), polygon, text))
                 count += 1
-            return (lines, count)
+            return (lines, count, count_skipped)
         except ErrorResponse as e:
             logger.info(f"ListTranscriptions failed {e.status_code} - {e.title} - {e.content} - {page_id}")
             raise e
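The check added above treats a transcription line as vertical when its bounding rectangle is taller than it is wide. As a standalone sketch of that heuristic (the function name and the int32 cast are mine, not part of the commit):

```python
import cv2
import numpy as np

def is_vertical(polygon) -> bool:
    # Same steps as in get_transcriptions: clip negative coordinates,
    # take the bounding rectangle, compare height against width.
    points = np.asarray(polygon).clip(0).astype(np.int32)  # cv2 expects int32 points
    x, y, w, h = cv2.boundingRect(points)
    return h > w
```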
@@ -156,12 +166,20 @@ class HTRDataGenerator:
             accepted_zones = self.get_accepted_zones(page_id)
         else:
             accepted_zones = []
-        lines, count = self.get_transcriptions(page_id, accepted_zones)
-        logger.debug(f"Num of lines {count}")
+        lines, count, count_skipped = self.get_transcriptions(page_id, accepted_zones)
         if count == 0:
+            self.skipped_pages_count += 1
             logger.info(f"Page {page_id} skipped, because it has no lines")
             return
+        logger.debug(f"Total num of lines {count + count_skipped}")
+        logger.debug(f"Num of accepted lines {count}")
+        logger.debug(f"Num of skipped lines {count_skipped}")
+        self.skipped_vertical_lines_count += count_skipped
+        self.accepted_lines_count += count
         full_image_url = image_data['s3_url']
         if full_image_url is None:
             full_image_url = image_data['url'] + '/full/full/0/default.jpg'
@@ -189,10 +207,9 @@ class HTRDataGenerator:
                 polygon_img = self.extract_polygon_image(img, polygon=polygon, rect=rect)
                 if self.module == 'kraken':
                     cv2.imwrite(f'{self.out_line_dir}/{page_id}_{i}.png', polygon_img)
-                    f.write(f"{page_id}_{i}.png")
+                    f.write(f"{page_id}_{i}.png\n")
                 else:
                     cv2.imwrite(f'{self.out_line_img_dir}/{page_id}_{i}.jpg', polygon_img)
-            f.close()
         else:
             raise ValueError("Unsupported extraction mode")
@@ -201,6 +218,8 @@ class HTRDataGenerator:
                 write_file(f"{self.out_line_dir}/{page_id}_{i}.gt.txt", text)
             else:
                 write_file(f"{self.out_line_text_dir}/{page_id}_{i}.txt", text)
+        if self.module == 'kraken':
+            f.close()

     @staticmethod
     def extract_polygon_image(img: 'np.ndarray', polygon: 'np.ndarray', rect: Box) -> 'np.ndarray':
@@ -217,11 +236,19 @@ class HTRDataGenerator:
         return dst2

     def run_pages(self, pages: list):
-        for page in tqdm.tqdm(pages):
-            page_id = page['id']
-            image_data = page['zone']['image']
-            logger.debug(f"Page {page_id}")
-            self.extract_lines(page_id, image_data)
+        if all(isinstance(n, str) for n in pages):
+            for page in pages:
+                elt = api_client.request('RetrieveElement', id=page)
+                page_id = elt['id']
+                image_data = elt['zone']['image']
+                logger.debug(f"Page {page_id}")
+                self.extract_lines(page_id, image_data)
+        else:
+            for page in tqdm.tqdm(pages):
+                page_id = page['id']
+                image_data = page['zone']['image']
+                logger.debug(f"Page {page_id}")
+                self.extract_lines(page_id, image_data)

     def run_volumes(self, volume_ids: list):
         for volume_id in tqdm.tqdm(volume_ids):
@@ -267,7 +294,7 @@ class KaldiPartitionSplitter:
     def page_level_split(self, line_ids: list) -> dict:
         page_ids = list({'_'.join(line_id.split('_')[:-1]) for line_id in line_ids})
-        random.shuffle(page_ids)
+        random.Random(SEED).shuffle(page_ids)
         page_count = len(page_ids)
         train_page_ids = page_ids[:round(page_count * self.split_train_ratio)]
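Seeding the shuffle with random.Random(SEED) makes the page-level split reproducible without touching the global random state. A minimal demonstration (the SEED value is illustrative; the repository defines its own constant):

```python
import random

SEED = 42  # assumed value, for illustration only

pages = ['p1', 'p2', 'p3', 'p4', 'p5']
random.Random(SEED).shuffle(pages)  # identical order on every run
print(pages)
```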
...@@ -357,17 +384,19 @@ def create_parser(): ...@@ -357,17 +384,19 @@ def create_parser():
'Elements of `volume_type` will be searched recursively in these folders') 'Elements of `volume_type` will be searched recursively in these folders')
parser.add_argument('--volumes', nargs='*', parser.add_argument('--volumes', nargs='*',
help='List of volume ids to be used, separated by spaces') help='List of volume ids to be used, separated by spaces')
# parser.add_argument('--pages', nargs='*', parser.add_argument('--pages', nargs='*',
# help='List of page ids to be used, separated by spaces') help='List of page ids to be used, separated by spaces')
parser.add_argument('-v', '--volume_type', type=str, default='volume', parser.add_argument('-v', '--volume_type', type=str, default='volume',
help='Volumes (1 level above page) may have a different name on corpora') help='Volumes (1 level above page) may have a different name on corpora')
parser.add_argument('--skip_vertical_lines', action='store_true', default=False,
help="skips vertical lines when downloading")
parser.add_argument('--accepted_slugs', nargs='*', parser.add_argument('--accepted_slugs', nargs='*',
help='List of accepted slugs for downloading transcriptions') help='List of accepted slugs for downloading transcriptions')
parser.add_argument('--accepted_classes', nargs='*', parser.add_argument('--accepted_classes', nargs='*',
help='List of accepted ml_class names. Filter lines by class of related elements') help='List of accepted ml_class names. Filter lines by class of related elements')
parser.add_argument('--filter_printed', action='store_true', parser.add_argument('--filter_printed', action='store_true',
help='Filter lines annotated as printed') help='Filter lines annotated as printed')
return parser return parser
@@ -377,7 +406,7 @@ def main():
     parser = create_parser()
     args = parser.parse_args()

-    if not args.dataset_name and not args.split_only and not args.module == "kraken":
+    if not args.dataset_name and not args.split_only and not args.format == "kraken":
         parser.error("--dataset_name must be specified (unless --split-only)")

     logger.info(f"ARGS {args} \n")
@@ -391,17 +420,23 @@ def main():
             extraction=args.extraction_mode,
             accepted_slugs=args.accepted_slugs,
             accepted_classes=args.accepted_classes,
-            filter_printed=args.filter_printed)
+            filter_printed=args.filter_printed,
+            skip_vertical_lines=args.skip_vertical_lines)

         # extract all the lines and transcriptions
-        # if args.pages:
-        #     kaldi_data_generator.run_pages(args.pages)
+        if args.pages:
+            data_generator.run_pages(args.pages)
         if args.volumes:
             data_generator.run_volumes(args.volumes)
         if args.folders:
             data_generator.run_folders(args.folders, args.volume_type)
         if args.corpora:
             data_generator.run_corpora(args.corpora, args.volume_type)
+
+        if data_generator.skipped_vertical_lines_count > 0:
+            logger.info(f"Number of skipped pages: {data_generator.skipped_pages_count}")
+            skipped_ratio = data_generator.skipped_vertical_lines_count / (
+                data_generator.skipped_vertical_lines_count + data_generator.accepted_lines_count)
+            logger.info(f"Skipped {data_generator.skipped_vertical_lines_count} vertical lines ({skipped_ratio}/1.0)")
     else:
         logger.info("Creating a split from already downloaded files")

     if not args.no_split:
...
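As a worked example of the ratio logged above (the counts are made up):

```python
# Hypothetical counts after a run
skipped_vertical_lines_count = 25
accepted_lines_count = 475

skipped_ratio = skipped_vertical_lines_count / (
    skipped_vertical_lines_count + accepted_lines_count)
# 25 / 500 = 0.05, logged as "Skipped 25 vertical lines (0.05/1.0)"
print(skipped_ratio)
```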