Skip to content
Snippets Groups Projects
Commit c5e4f26e authored by Martin
Browse files

filter text_lines by style class - handwritten, typewritten or neither (other)

parent 3fbdf930
No related branches found
No related tags found
1 merge request: !22 Add style filter (handwritten, typewritten); support ignored_classes
Pipeline #74327 failed
...@@ -64,6 +64,15 @@ class Extraction(Enum): ...@@ -64,6 +64,15 @@ class Extraction(Enum):
skew_min_area_rect: int = 6 skew_min_area_rect: int = 6
class Style(Enum):
    """Writing style used to filter text lines.

    ``other`` stands for elements that carry neither the handwritten nor the
    typewritten class.
    """

    handwritten: str = "handwritten"
    typewritten: str = "typewritten"
    other: str = "other"


# Style names that are matched against element class names; ``other`` is
# deliberately left out because it denotes the absence of both.
STYLE_CLASSES = [Style.handwritten.name, Style.typewritten.name]
class HTRDataGenerator: class HTRDataGenerator:
def __init__( def __init__(
self, self,
...@@ -73,7 +82,7 @@ class HTRDataGenerator: ...@@ -73,7 +82,7 @@ class HTRDataGenerator:
grayscale=True, grayscale=True,
extraction=Extraction.boundingRect, extraction=Extraction.boundingRect,
accepted_classes=None, accepted_classes=None,
filter_printed=False, style=None,
skip_vertical_lines=False, skip_vertical_lines=False,
accepted_worker_version_ids=None, accepted_worker_version_ids=None,
transcription_type=TEXT_LINE, transcription_type=TEXT_LINE,
...@@ -96,7 +105,8 @@ class HTRDataGenerator: ...@@ -96,7 +105,8 @@ class HTRDataGenerator:
self.should_filter_by_class = bool(self.accepted_classes) self.should_filter_by_class = bool(self.accepted_classes)
self.accepted_worker_version_ids = accepted_worker_version_ids self.accepted_worker_version_ids = accepted_worker_version_ids
self.should_filter_by_worker = bool(self.accepted_worker_version_ids) self.should_filter_by_worker = bool(self.accepted_worker_version_ids)
self.should_filter_printed = filter_printed self.style = style
self.should_filter_by_style = bool(self.style)
self.transcription_type = transcription_type self.transcription_type = transcription_type
self.skip_vertical_lines = skip_vertical_lines self.skip_vertical_lines = skip_vertical_lines
self.skipped_pages_count = 0 self.skipped_pages_count = 0
...@@ -174,17 +184,47 @@ class HTRDataGenerator: ...@@ -174,17 +184,47 @@ class HTRDataGenerator:
for elt in self.api_client.cached_paginate( for elt in self.api_client.cached_paginate(
"ListElementChildren", id=page_id, with_classes=True "ListElementChildren", id=page_id, with_classes=True
): ):
printed = True elem_classes = [c
for classification in elt["classes"]: for c in elt["classes"]
if classification["ml_class"]["name"] == "handwritten": if c["state"] != "rejected"
printed = False ]
for classification in elt["classes"]:
if classification["ml_class"]["name"] in self.accepted_classes: should_accept = True
if self.should_filter_printed: if self.should_filter_by_class:
if not printed: # at first filter to only have elements with accepted classes
accepted_zones.append(elt["zone"]["id"]) should_accept = False
else: for classification in elem_classes:
accepted_zones.append(elt["zone"]["id"]) class_name = classification["ml_class"]["name"]
if class_name in self.accepted_classes:
should_accept = True
break
if not should_accept:
continue
if self.should_filter_by_style:
style_counts = Counter()
for classification in elem_classes:
class_name = classification["ml_class"]["name"]
if class_name in STYLE_CLASSES:
style_counts[class_name] += 1
print("STYLE", style_counts)
if len(style_counts) == 0:
# no handwritten or typewritten found, so other
found_class = Style.other
elif len(style_counts) == 1:
found_class = list(style_counts.keys())[0]
found_class = Style(found_class)
else:
raise ValueError(f"Multiple style classes on the same element! {elt['id']} - {elem_classes}")
if found_class == self.style:
accepted_zones.append(elt["zone"]["id"])
else:
accepted_zones.append(elt["zone"]["id"])
logger.info( logger.info(
"Number of accepted zone for page {} : {}".format( "Number of accepted zone for page {} : {}".format(
page_id, len(accepted_zones) page_id, len(accepted_zones)
...@@ -197,6 +237,7 @@ class HTRDataGenerator: ...@@ -197,6 +237,7 @@ class HTRDataGenerator:
) )
raise e raise e
def _validate_transcriptions(self, page_id: str, lines: List[TranscriptionData]): def _validate_transcriptions(self, page_id: str, lines: List[TranscriptionData]):
if not lines: if not lines:
return return
...@@ -256,7 +297,7 @@ class HTRDataGenerator: ...@@ -256,7 +297,7 @@ class HTRDataGenerator:
): ):
continue continue
if ( if (
self.should_filter_by_class (self.should_filter_by_class or self.should_filter_by_style)
and res["element"]["zone"]["id"] not in accepted_zones and res["element"]["zone"]["id"] not in accepted_zones
): ):
continue continue
...@@ -362,7 +403,7 @@ class HTRDataGenerator: ...@@ -362,7 +403,7 @@ class HTRDataGenerator:
cv2.imwrite(f"{self.out_line_img_dir}/{page_id}_{i}.jpg", line_img) cv2.imwrite(f"{self.out_line_img_dir}/{page_id}_{i}.jpg", line_img)
def extract_lines(self, page_id: str, image_data: dict): def extract_lines(self, page_id: str, image_data: dict):
if self.should_filter_by_class: if self.should_filter_by_class or self.should_filter_by_style:
accepted_zones = self.get_accepted_zones(page_id) accepted_zones = self.get_accepted_zones(page_id)
else: else:
accepted_zones = [] accepted_zones = []
...@@ -815,9 +856,11 @@ def create_parser(): ...@@ -815,9 +856,11 @@ def create_parser():
) )
parser.add_argument( parser.add_argument(
"--filter_printed", "--style",
action="store_true", type=lambda x: Style[x.lower()],
help="Filter lines annotated as printed", default=None,
help=f"Filter line images by style class. 'other' corresponds to line elements that "
f"have neither handwritten or typewritten class : {[s.name for s in Style]}",
) )
parser.add_argument( parser.add_argument(
...@@ -857,6 +900,11 @@ def main(): ...@@ -857,6 +900,11 @@ def main():
if not args.dataset_name and not args.split_only and not args.format == "kraken": if not args.dataset_name and not args.split_only and not args.format == "kraken":
parser.error("--dataset_name must be specified (unless --split-only)") parser.error("--dataset_name must be specified (unless --split-only)")
if args.style and args.accepted_classes:
if any(c in args.accepted_classes for c in STYLE_CLASSES):
parser.error(f"--style class values ({STYLE_CLASSES}) shouldn't be in the accepted_classes list "
"if both --style and --accepted_classes are used together.")
logger.info(f"ARGS {args} \n") logger.info(f"ARGS {args} \n")
api_client = create_api_client(args.cache_dir) api_client = create_api_client(args.cache_dir)
...@@ -869,7 +917,7 @@ def main(): ...@@ -869,7 +917,7 @@ def main():
grayscale=args.grayscale, grayscale=args.grayscale,
extraction=args.extraction_mode, extraction=args.extraction_mode,
accepted_classes=args.accepted_classes, accepted_classes=args.accepted_classes,
filter_printed=args.filter_printed, style=args.style,
skip_vertical_lines=args.skip_vertical_lines, skip_vertical_lines=args.skip_vertical_lines,
transcription_type=args.transcription_type, transcription_type=args.transcription_type,
accepted_worker_version_ids=args.accepted_worker_version_ids, accepted_worker_version_ids=args.accepted_worker_version_ids,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment