Commit 1ccedc3b authored by Martin Maarand

Merge branch 'skip_vertical_lines' into 'master'

Add option to skip vertical lines

See merge request teklia/kaldi_data_generator!4
parents 90175f84 45c9e42e
README.md
@@ -20,6 +20,8 @@ Use help to list possible parameters:
 ```bash
 python kaldi_data_generator.py --help
 ```
+There is also a `--skip_vertical_lines` option that skips all vertical transcriptions.
+
 #### Kaldi format
 Simple example:
 ```bash
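On the command line, the new flag combines with the existing options; a hypothetical invocation (the dataset name and volume id are placeholders):

```bash
python kaldi_data_generator.py --dataset_name my_dataset \
    --volumes <volume_id> --skip_vertical_lines
```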
kaldi_data_generator.py
@@ -65,7 +65,7 @@ class Extraction(Enum):
 class HTRDataGenerator:
     def __init__(self, module, dataset_name='foo', out_dir_base='/tmp/kaldi_data', grayscale=True,
-                 extraction=Extraction.boundingRect, accepted_slugs=None, accepted_classes=None, filter_printed=False):
+                 extraction=Extraction.boundingRect, accepted_slugs=None, accepted_classes=None, filter_printed=False, skip_vertical_lines=False):
         self.module = module
         self.out_dir_base = out_dir_base
         self.dataset_name = dataset_name
@@ -76,6 +76,10 @@ class HTRDataGenerator:
         self.accepted_classes = accepted_classes
         self.should_filter_by_class = bool(self.accepted_classes)
         self.should_filter_printed = filter_printed
+        self.skip_vertical_lines = skip_vertical_lines
+        self.skipped_pages_count = 0
+        self.skipped_vertical_lines_count = 0
+        self.accepted_lines_count = 0
         if self.module == 'kraken':
             self.out_line_dir = out_dir_base
             os.makedirs(self.out_line_dir, exist_ok=True)
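With the new keyword argument, callers opt in when constructing the generator; a minimal sketch (the module and output directory values are hypothetical):

```python
# Hypothetical construction; the three counters added above start at zero
# and are updated as pages are processed.
generator = HTRDataGenerator(module='kraken',
                             dataset_name='demo',
                             out_dir_base='/tmp/kaldi_data',
                             skip_vertical_lines=True)
print(generator.skipped_vertical_lines_count)  # 0 before any extraction
```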
@@ -98,7 +102,7 @@ class HTRDataGenerator:
                              out_full_img_path, format='jpeg')
         img = cv2.imread(out_full_img_path)
         return img
-
+
     def get_accepted_zones(self, page_id: str):
         try:
             accepted_zones = []
@@ -122,6 +126,7 @@ class HTRDataGenerator:
     def get_transcriptions(self, page_id: str, accepted_zones):
         count = 0
+        count_skipped = 0
         lines = []
         try:
             for res in api_client.paginate('ListTranscriptions', id=page_id, recursive=True):
@@ -144,9 +149,14 @@ class HTRDataGenerator:
                     polygon = np.asarray(polygon).clip(0)
                     [x, y, w, h] = cv2.boundingRect(polygon)
+                    if self.skip_vertical_lines:
+                        if h > w:
+                            count_skipped += 1
+                            continue
+
                     lines.append(((x, y, w, h), polygon, text))
                     count += 1
-            return (lines, count)
+            return (lines, count, count_skipped)
         except ErrorResponse as e:
             logger.info(f"ListTranscriptions failed {e.status_code} - {e.title} - {e.content} - {page_id}")
             raise e
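The vertical-line test is purely geometric: a transcription is dropped when the bounding rectangle of its polygon is taller than it is wide. A minimal sketch of the same check on a hypothetical polygon:

```python
import cv2
import numpy as np

# A tall, narrow quadrilateral stands in for a vertical transcription line
# (coordinates are hypothetical; int32 keeps OpenCV happy).
polygon = np.asarray([[10, 5], [30, 5], [30, 200], [10, 200]],
                     dtype=np.int32).clip(0)
x, y, w, h = cv2.boundingRect(polygon)
is_vertical = h > w  # True: the box is far taller than it is wide
```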
@@ -156,12 +166,20 @@ class HTRDataGenerator:
             accepted_zones = self.get_accepted_zones(page_id)
         else:
             accepted_zones = []
-        lines, count = self.get_transcriptions(page_id, accepted_zones)
-        logger.debug(f"Num of lines {count}")
+        lines, count, count_skipped = self.get_transcriptions(page_id, accepted_zones)
+        if count == 0:
+            self.skipped_pages_count += 1
+            logger.info(f"Page {page_id} skipped because it has no lines")
+            return
+        logger.debug(f"Total num of lines {count + count_skipped}")
+        logger.debug(f"Num of accepted lines {count}")
+        logger.debug(f"Num of skipped lines {count_skipped}")
+        self.skipped_vertical_lines_count += count_skipped
+        self.accepted_lines_count += count
         full_image_url = image_data['s3_url']
         if full_image_url is None:
             full_image_url = image_data['url'] + '/full/full/0/default.jpg'
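The image URL fallback follows the IIIF Image API pattern, where the `/full/full/0/default.jpg` suffix requests the full region at full size with no rotation. A minimal sketch, assuming a hypothetical IIIF base URL:

```python
# Hypothetical image record with no pre-built S3 URL.
image_data = {'s3_url': None, 'url': 'https://iiif.example.com/iiif/abc123'}

full_image_url = image_data['s3_url']
if full_image_url is None:
    # region=full, size=full, rotation=0, quality=default, format=jpg
    full_image_url = image_data['url'] + '/full/full/0/default.jpg'
```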
@@ -189,10 +207,9 @@ class HTRDataGenerator:
                     polygon_img = self.extract_polygon_image(img, polygon=polygon, rect=rect)
                     if self.module == 'kraken':
                         cv2.imwrite(f'{self.out_line_dir}/{page_id}_{i}.png', polygon_img)
-                        f.write(f"{page_id}_{i}.png")
+                        f.write(f"{page_id}_{i}.png\n")
                     else:
                         cv2.imwrite(f'{self.out_line_img_dir}/{page_id}_{i}.jpg', polygon_img)
-            f.close()
         else:
             raise ValueError("Unsupported extraction mode")
@@ -201,6 +218,8 @@ class HTRDataGenerator:
                 write_file(f"{self.out_line_dir}/{page_id}_{i}.gt.txt", text)
             else:
                 write_file(f"{self.out_line_text_dir}/{page_id}_{i}.txt", text)
+        if self.module == 'kraken':
+            f.close()

     @staticmethod
     def extract_polygon_image(img: 'np.ndarray', polygon: 'np.ndarray', rect: Box) -> 'np.ndarray':
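Taken with the previous hunk, this fixes the manifest handling: each entry now ends with a newline, and `f` (which is only opened in kraken mode) is closed once at the end, only when it exists. A sketch of an equivalent, exception-safe structure using a context manager (the manifest path and `line_count` are hypothetical):

```python
# Hypothetical equivalent: the context manager closes the manifest even if
# an exception is raised while writing line entries.
if self.module == 'kraken':
    with open(f'{self.out_line_dir}/manifest.txt', 'w') as f:
        for i in range(line_count):  # line_count is hypothetical
            f.write(f'{page_id}_{i}.png\n')
```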
@@ -217,11 +236,19 @@ class HTRDataGenerator:
         return dst2

     def run_pages(self, pages: list):
-        for page in tqdm.tqdm(pages):
-            page_id = page['id']
-            image_data = page['zone']['image']
-            logger.debug(f"Page {page_id}")
-            self.extract_lines(page_id, image_data)
+        if all(isinstance(n, str) for n in pages):
+            for page in pages:
+                elt = api_client.request('RetrieveElement', id=page)
+                page_id = elt['id']
+                image_data = elt['zone']['image']
+                logger.debug(f"Page {page_id}")
+                self.extract_lines(page_id, image_data)
+        else:
+            for page in tqdm.tqdm(pages):
+                page_id = page['id']
+                image_data = page['zone']['image']
+                logger.debug(f"Page {page_id}")
+                self.extract_lines(page_id, image_data)

     def run_volumes(self, volume_ids: list):
         for volume_id in tqdm.tqdm(volume_ids):
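`run_pages` now accepts either bare page id strings, each fetched through `RetrieveElement`, or already-retrieved page dicts; the `isinstance` check picks the branch. Both call styles, with hypothetical ids and a minimal dict shape:

```python
# Page ids as strings: each is resolved via the API first.
generator.run_pages(['page-id-1', 'page-id-2'])

# Already-fetched page elements: used directly, with a progress bar.
generator.run_pages([{'id': 'page-id-1',
                      'zone': {'image': {'s3_url': None,
                                         'url': 'https://iiif.example.com/iiif/p1'}}}])
```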
@@ -267,7 +294,7 @@ class KaldiPartitionSplitter:
     def page_level_split(self, line_ids: list) -> dict:
         page_ids = list({'_'.join(line_id.split('_')[:-1]) for line_id in line_ids})
-        random.shuffle(page_ids)
+        random.Random(SEED).shuffle(page_ids)
         page_count = len(page_ids)
         train_page_ids = page_ids[:round(page_count * self.split_train_ratio)]
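Replacing `random.shuffle` with a dedicated `random.Random(SEED)` instance makes the page-level split reproducible across runs without touching the global random state. A quick demonstration (the seed value stands in for the module-level `SEED` constant):

```python
import random

SEED = 42  # placeholder for the module-level constant
page_ids = ['p1', 'p2', 'p3', 'p4']

a, b = list(page_ids), list(page_ids)
random.Random(SEED).shuffle(a)
random.Random(SEED).shuffle(b)
assert a == b  # same order on every run with the same seed
```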
@@ -357,17 +384,19 @@ def create_parser():
                              'Elements of `volume_type` will be searched recursively in these folders')
     parser.add_argument('--volumes', nargs='*',
                         help='List of volume ids to be used, separated by spaces')
-    # parser.add_argument('--pages', nargs='*',
-    #                     help='List of page ids to be used, separated by spaces')
+    parser.add_argument('--pages', nargs='*',
+                        help='List of page ids to be used, separated by spaces')
     parser.add_argument('-v', '--volume_type', type=str, default='volume',
                         help='Volumes (1 level above page) may have a different name on corpora')
+    parser.add_argument('--skip_vertical_lines', action='store_true', default=False,
+                        help='Skip vertical lines when downloading')
     parser.add_argument('--accepted_slugs', nargs='*',
                         help='List of accepted slugs for downloading transcriptions')
     parser.add_argument('--accepted_classes', nargs='*',
                         help='List of accepted ml_class names. Filter lines by class of related elements')
     parser.add_argument('--filter_printed', action='store_true',
                         help='Filter lines annotated as printed')
     return parser
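Because `--skip_vertical_lines` uses `action='store_true'`, the flag takes no value and defaults to `False`. Assuming the parser has no required arguments (as the diff suggests), its behavior can be checked directly:

```python
parser = create_parser()

args = parser.parse_args(['--skip_vertical_lines'])
assert args.skip_vertical_lines is True

args = parser.parse_args([])
assert args.skip_vertical_lines is False
```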
@@ -377,7 +406,7 @@ def main():
     parser = create_parser()
     args = parser.parse_args()

-    if not args.dataset_name and not args.split_only and not args.module == "kraken":
+    if not args.dataset_name and not args.split_only and not args.format == "kraken":
         parser.error("--dataset_name must be specified (unless --split-only)")

     logger.info(f"ARGS {args} \n")
@@ -391,17 +420,23 @@ def main():
             extraction=args.extraction_mode,
             accepted_slugs=args.accepted_slugs,
             accepted_classes=args.accepted_classes,
-            filter_printed=args.filter_printed)
+            filter_printed=args.filter_printed,
+            skip_vertical_lines=args.skip_vertical_lines)

         # extract all the lines and transcriptions
-        # if args.pages:
-        #     kaldi_data_generator.run_pages(args.pages)
+        if args.pages:
+            data_generator.run_pages(args.pages)
         if args.volumes:
            data_generator.run_volumes(args.volumes)
         if args.folders:
             data_generator.run_folders(args.folders, args.volume_type)
         if args.corpora:
             data_generator.run_corpora(args.corpora, args.volume_type)
+        if data_generator.skipped_vertical_lines_count > 0:
+            logger.info(f"Number of skipped pages: {data_generator.skipped_pages_count}")
+            skipped_ratio = data_generator.skipped_vertical_lines_count / (
+                data_generator.skipped_vertical_lines_count + data_generator.accepted_lines_count)
+            logger.info(f"Skipped {data_generator.skipped_vertical_lines_count} vertical lines ({skipped_ratio}/1.0)")
     else:
         logger.info("Creating a split from already downloaded files")
         if not args.no_split:
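The logged ratio is skipped vertical lines over all lines seen (skipped plus accepted), so it always falls between 0 and 1. A worked example with hypothetical counts:

```python
skipped_vertical_lines_count = 25
accepted_lines_count = 75

skipped_ratio = skipped_vertical_lines_count / (
    skipped_vertical_lines_count + accepted_lines_count)
print(skipped_ratio)  # 0.25, logged as "Skipped 25 vertical lines (0.25/1.0)"
```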