Skip to content
Snippets Groups Projects
Commit c5e4f26e authored by Martin
Browse files

filter text_lines by style class - handwritten, typewritten or neither (other)

parent 3fbdf930
No related branches found
No related tags found
1 merge request: !22 "Add style filter (handwritten, typewritten); support ignored_classes"
Pipeline #74327 failed
......@@ -64,6 +64,15 @@ class Extraction(Enum):
skew_min_area_rect: int = 6
class Style(Enum):
    """Writing style used to filter text lines.

    Every member's value is the same string as its name, so a member can be
    obtained either by name (``Style["handwritten"]``) or by value
    (``Style("handwritten")``).
    """

    # Members are intentionally not annotated: an annotation such as
    # ``handwritten: str`` is misleading (the member is a ``Style`` instance,
    # not a ``str``) and is reported as an error by type checkers.
    handwritten = "handwritten"
    typewritten = "typewritten"
    # Neither handwritten nor typewritten.
    other = "other"
# Names of the explicit style classes; ``Style.other`` is deliberately left
# out of this list.
STYLE_CLASSES = [Style.handwritten.name, Style.typewritten.name]
class HTRDataGenerator:
def __init__(
self,
......@@ -73,7 +82,7 @@ class HTRDataGenerator:
grayscale=True,
extraction=Extraction.boundingRect,
accepted_classes=None,
filter_printed=False,
style=None,
skip_vertical_lines=False,
accepted_worker_version_ids=None,
transcription_type=TEXT_LINE,
......@@ -96,7 +105,8 @@ class HTRDataGenerator:
self.should_filter_by_class = bool(self.accepted_classes)
self.accepted_worker_version_ids = accepted_worker_version_ids
self.should_filter_by_worker = bool(self.accepted_worker_version_ids)
self.should_filter_printed = filter_printed
self.style = style
self.should_filter_by_style = bool(self.style)
self.transcription_type = transcription_type
self.skip_vertical_lines = skip_vertical_lines
self.skipped_pages_count = 0
......@@ -174,17 +184,47 @@ class HTRDataGenerator:
for elt in self.api_client.cached_paginate(
"ListElementChildren", id=page_id, with_classes=True
):
printed = True
for classification in elt["classes"]:
if classification["ml_class"]["name"] == "handwritten":
printed = False
for classification in elt["classes"]:
if classification["ml_class"]["name"] in self.accepted_classes:
if self.should_filter_printed:
if not printed:
accepted_zones.append(elt["zone"]["id"])
else:
accepted_zones.append(elt["zone"]["id"])
elem_classes = [c
for c in elt["classes"]
if c["state"] != "rejected"
]
should_accept = True
if self.should_filter_by_class:
# at first filter to only have elements with accepted classes
should_accept = False
for classification in elem_classes:
class_name = classification["ml_class"]["name"]
if class_name in self.accepted_classes:
should_accept = True
break
if not should_accept:
continue
if self.should_filter_by_style:
style_counts = Counter()
for classification in elem_classes:
class_name = classification["ml_class"]["name"]
if class_name in STYLE_CLASSES:
style_counts[class_name] += 1
print("STYLE", style_counts)
if len(style_counts) == 0:
# no handwritten or typewritten found, so other
found_class = Style.other
elif len(style_counts) == 1:
found_class = list(style_counts.keys())[0]
found_class = Style(found_class)
else:
raise ValueError(f"Multiple style classes on the same element! {elt['id']} - {elem_classes}")
if found_class == self.style:
accepted_zones.append(elt["zone"]["id"])
else:
accepted_zones.append(elt["zone"]["id"])
logger.info(
"Number of accepted zone for page {} : {}".format(
page_id, len(accepted_zones)
......@@ -197,6 +237,7 @@ class HTRDataGenerator:
)
raise e
def _validate_transcriptions(self, page_id: str, lines: List[TranscriptionData]):
if not lines:
return
......@@ -256,7 +297,7 @@ class HTRDataGenerator:
):
continue
if (
self.should_filter_by_class
(self.should_filter_by_class or self.should_filter_by_style)
and res["element"]["zone"]["id"] not in accepted_zones
):
continue
......@@ -362,7 +403,7 @@ class HTRDataGenerator:
cv2.imwrite(f"{self.out_line_img_dir}/{page_id}_{i}.jpg", line_img)
def extract_lines(self, page_id: str, image_data: dict):
if self.should_filter_by_class:
if self.should_filter_by_class or self.should_filter_by_style:
accepted_zones = self.get_accepted_zones(page_id)
else:
accepted_zones = []
......@@ -815,9 +856,11 @@ def create_parser():
)
parser.add_argument(
"--filter_printed",
action="store_true",
help="Filter lines annotated as printed",
"--style",
type=lambda x: Style[x.lower()],
default=None,
help=f"Filter line images by style class. 'other' corresponds to line elements that "
f"have neither handwritten or typewritten class : {[s.name for s in Style]}",
)
parser.add_argument(
......@@ -857,6 +900,11 @@ def main():
if not args.dataset_name and not args.split_only and not args.format == "kraken":
parser.error("--dataset_name must be specified (unless --split-only)")
if args.style and args.accepted_classes:
if any(c in args.accepted_classes for c in STYLE_CLASSES):
parser.error(f"--style class values ({STYLE_CLASSES}) shouldn't be in the accepted_classes list "
"if both --style and --accepted_classes are used together.")
logger.info(f"ARGS {args} \n")
api_client = create_api_client(args.cache_dir)
......@@ -869,7 +917,7 @@ def main():
grayscale=args.grayscale,
extraction=args.extraction_mode,
accepted_classes=args.accepted_classes,
filter_printed=args.filter_printed,
style=args.style,
skip_vertical_lines=args.skip_vertical_lines,
transcription_type=args.transcription_type,
accepted_worker_version_ids=args.accepted_worker_version_ids,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment