Skip to content
Snippets Groups Projects
Commit 9b537f02 authored by Martin Maarand's avatar Martin Maarand
Browse files

Add style filter (handwritten, typewritten); support ignored_classes

parent 3fbdf930
No related branches found
No related tags found
1 merge request!22Add style filter (handwritten, typewritten); support ignored_classes
Pipeline #74332 passed
......@@ -64,6 +64,15 @@ class Extraction(Enum):
skew_min_area_rect: int = 6
class Style(Enum):
handwritten: str = "handwritten"
typewritten: str = "typewritten"
other: str = "other"
STYLE_CLASSES = [s.name for s in [Style.handwritten, Style.typewritten]]
class HTRDataGenerator:
def __init__(
self,
......@@ -73,7 +82,8 @@ class HTRDataGenerator:
grayscale=True,
extraction=Extraction.boundingRect,
accepted_classes=None,
filter_printed=False,
ignored_classes=None,
style=None,
skip_vertical_lines=False,
accepted_worker_version_ids=None,
transcription_type=TEXT_LINE,
......@@ -93,10 +103,14 @@ class HTRDataGenerator:
self.grayscale = grayscale
self.extraction_mode = extraction
self.accepted_classes = accepted_classes
self.should_filter_by_class = bool(self.accepted_classes)
self.ignored_classes = ignored_classes
self.should_filter_by_class = bool(self.accepted_classes) or bool(
self.ignored_classes
)
self.accepted_worker_version_ids = accepted_worker_version_ids
self.should_filter_by_worker = bool(self.accepted_worker_version_ids)
self.should_filter_printed = filter_printed
self.style = style
self.should_filter_by_style = bool(self.style)
self.transcription_type = transcription_type
self.skip_vertical_lines = skip_vertical_lines
self.skipped_pages_count = 0
......@@ -174,17 +188,49 @@ class HTRDataGenerator:
for elt in self.api_client.cached_paginate(
"ListElementChildren", id=page_id, with_classes=True
):
printed = True
for classification in elt["classes"]:
if classification["ml_class"]["name"] == "handwritten":
printed = False
for classification in elt["classes"]:
if classification["ml_class"]["name"] in self.accepted_classes:
if self.should_filter_printed:
if not printed:
accepted_zones.append(elt["zone"]["id"])
else:
accepted_zones.append(elt["zone"]["id"])
elem_classes = [c for c in elt["classes"] if c["state"] != "rejected"]
should_accept = True
if self.should_filter_by_class:
# at first filter to only have elements with accepted classes
# if accepted classes list is empty then should accept all
# except for ignored classes
should_accept = len(self.accepted_classes) == 0
for classification in elem_classes:
class_name = classification["ml_class"]["name"]
if class_name in self.accepted_classes:
should_accept = True
break
elif class_name in self.ignored_classes:
should_accept = False
break
if not should_accept:
continue
if self.should_filter_by_style:
style_counts = Counter()
for classification in elem_classes:
class_name = classification["ml_class"]["name"]
if class_name in STYLE_CLASSES:
style_counts[class_name] += 1
if len(style_counts) == 0:
# no handwritten or typewritten found, so other
found_class = Style.other
elif len(style_counts) == 1:
found_class = list(style_counts.keys())[0]
found_class = Style(found_class)
else:
raise ValueError(
f"Multiple style classes on the same element! {elt['id']} - {elem_classes}"
)
if found_class == self.style:
accepted_zones.append(elt["zone"]["id"])
else:
accepted_zones.append(elt["zone"]["id"])
logger.info(
"Number of accepted zone for page {} : {}".format(
page_id, len(accepted_zones)
......@@ -255,9 +301,8 @@ class HTRDataGenerator:
and res["worker_version_id"] not in self.accepted_worker_version_ids
):
continue
if (
self.should_filter_by_class
and res["element"]["zone"]["id"] not in accepted_zones
if (self.should_filter_by_class or self.should_filter_by_style) and (
res["element"]["zone"]["id"] not in accepted_zones
):
continue
if res["element"]["type"] != self.transcription_type:
......@@ -362,7 +407,7 @@ class HTRDataGenerator:
cv2.imwrite(f"{self.out_line_img_dir}/{page_id}_{i}.jpg", line_img)
def extract_lines(self, page_id: str, image_data: dict):
if self.should_filter_by_class:
if self.should_filter_by_class or self.should_filter_by_style:
accepted_zones = self.get_accepted_zones(page_id)
else:
accepted_zones = []
......@@ -796,10 +841,18 @@ def create_parser():
help="skips vertical lines when downloading",
)
parser.add_argument(
"--ignored_classes",
nargs="*",
default=[],
help="List of ignored ml_class names. Filter lines by class",
)
parser.add_argument(
"--accepted_classes",
nargs="*",
help="List of accepted ml_class names. Filter lines by class of related elements",
default=[],
help="List of accepted ml_class names. Filter lines by class",
)
parser.add_argument(
......@@ -815,9 +868,11 @@ def create_parser():
)
parser.add_argument(
"--filter_printed",
action="store_true",
help="Filter lines annotated as printed",
"--style",
type=lambda x: Style[x.lower()],
default=None,
help=f"Filter line images by style class. 'other' corresponds to line elements that "
f"have neither handwritten or typewritten class : {[s.name for s in Style]}",
)
parser.add_argument(
......@@ -857,6 +912,22 @@ def main():
if not args.dataset_name and not args.split_only and not args.format == "kraken":
parser.error("--dataset_name must be specified (unless --split-only)")
if args.accepted_classes and args.ignored_classes:
if set(args.accepted_classes) & set(args.ignored_classes):
parser.error(
f"--accepted_classes and --ignored_classes values must not overlap ({args.accepted_classes} - {args.ignored_classes})"
)
if args.style and (args.accepted_classes or args.ignored_classes):
if set(STYLE_CLASSES) & (
set(args.accepted_classes) | set(args.ignored_classes)
):
parser.error(
f"--style class values ({STYLE_CLASSES}) shouldn't be in the accepted_classes list "
f"(or ignored_classes list) "
"if both --style and --accepted_classes (or --ignored_classes) are used together."
)
logger.info(f"ARGS {args} \n")
api_client = create_api_client(args.cache_dir)
......@@ -869,7 +940,8 @@ def main():
grayscale=args.grayscale,
extraction=args.extraction_mode,
accepted_classes=args.accepted_classes,
filter_printed=args.filter_printed,
ignored_classes=args.ignored_classes,
style=args.style,
skip_vertical_lines=args.skip_vertical_lines,
transcription_type=args.transcription_type,
accepted_worker_version_ids=args.accepted_worker_version_ids,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment