From c5e4f26e610f64177099a1796365128fc577d4ff Mon Sep 17 00:00:00 2001
From: Martin <maarand@teklia.com>
Date: Tue, 15 Feb 2022 17:23:43 +0100
Subject: [PATCH] filter text_lines by style class - handwritten, typewritten
 or neither (other)

---
 kaldi_data_generator/main.py | 86 ++++++++++++++++++++++++++++--------
 1 file changed, 67 insertions(+), 19 deletions(-)

diff --git a/kaldi_data_generator/main.py b/kaldi_data_generator/main.py
index 4123088..cf204cf 100644
--- a/kaldi_data_generator/main.py
+++ b/kaldi_data_generator/main.py
@@ -64,6 +64,15 @@ class Extraction(Enum):
     skew_min_area_rect: int = 6
 
 
+class Style(Enum):
+    handwritten: str = "handwritten"
+    typewritten: str = "typewritten"
+    other: str = "other"
+
+
+STYLE_CLASSES = [s.name for s in [Style.handwritten, Style.typewritten]]
+
+
 class HTRDataGenerator:
     def __init__(
         self,
@@ -73,7 +82,7 @@ class HTRDataGenerator:
         grayscale=True,
         extraction=Extraction.boundingRect,
         accepted_classes=None,
-        filter_printed=False,
+        style=None,
         skip_vertical_lines=False,
         accepted_worker_version_ids=None,
         transcription_type=TEXT_LINE,
@@ -96,7 +105,8 @@ class HTRDataGenerator:
         self.should_filter_by_class = bool(self.accepted_classes)
         self.accepted_worker_version_ids = accepted_worker_version_ids
         self.should_filter_by_worker = bool(self.accepted_worker_version_ids)
-        self.should_filter_printed = filter_printed
+        self.style = style
+        self.should_filter_by_style = bool(self.style)
         self.transcription_type = transcription_type
         self.skip_vertical_lines = skip_vertical_lines
         self.skipped_pages_count = 0
@@ -174,17 +184,47 @@ class HTRDataGenerator:
             for elt in self.api_client.cached_paginate(
                 "ListElementChildren", id=page_id, with_classes=True
             ):
-                printed = True
-                for classification in elt["classes"]:
-                    if classification["ml_class"]["name"] == "handwritten":
-                        printed = False
-                for classification in elt["classes"]:
-                    if classification["ml_class"]["name"] in self.accepted_classes:
-                        if self.should_filter_printed:
-                            if not printed:
-                                accepted_zones.append(elt["zone"]["id"])
-                        else:
-                            accepted_zones.append(elt["zone"]["id"])
+                elem_classes = [c
+                                for c in elt["classes"]
+                                if c["state"] != "rejected"
+                                ]
+
+                should_accept = True
+                if self.should_filter_by_class:
+                    # first, keep only elements carrying at least one accepted class
+                    should_accept = False
+                    for classification in elem_classes:
+                        class_name = classification["ml_class"]["name"]
+                        if class_name in self.accepted_classes:
+                            should_accept = True
+                            break
+
+                if not should_accept:
+                    continue
+
+                if self.should_filter_by_style:
+                    style_counts = Counter()
+                    for classification in elem_classes:
+                        class_name = classification["ml_class"]["name"]
+                        if class_name in STYLE_CLASSES:
+                            style_counts[class_name] += 1
+
+                    logger.debug(f"Style counts for {elt['id']}: {style_counts}")
+
+                    if len(style_counts) == 0:
+                        # neither handwritten nor typewritten found, so other
+                        found_class = Style.other
+                    elif len(style_counts) == 1:
+                        found_class = list(style_counts.keys())[0]
+                        found_class = Style(found_class)
+                    else:
+                        raise ValueError(f"Multiple style classes on the same element! {elt['id']} - {elem_classes}")
+
+                    if found_class == self.style:
+                        accepted_zones.append(elt["zone"]["id"])
+                else:
+                    accepted_zones.append(elt["zone"]["id"])
+
             logger.info(
                 "Number of accepted zone for page {} : {}".format(
                     page_id, len(accepted_zones)
@@ -197,6 +237,7 @@ class HTRDataGenerator:
             )
             raise e
 
+
     def _validate_transcriptions(self, page_id: str, lines: List[TranscriptionData]):
         if not lines:
             return
@@ -256,7 +297,7 @@ class HTRDataGenerator:
                 ):
                     continue
                 if (
-                    self.should_filter_by_class
+                    (self.should_filter_by_class or self.should_filter_by_style)
                     and res["element"]["zone"]["id"] not in accepted_zones
                 ):
                     continue
@@ -362,7 +403,7 @@ class HTRDataGenerator:
             cv2.imwrite(f"{self.out_line_img_dir}/{page_id}_{i}.jpg", line_img)
 
     def extract_lines(self, page_id: str, image_data: dict):
-        if self.should_filter_by_class:
+        if self.should_filter_by_class or self.should_filter_by_style:
             accepted_zones = self.get_accepted_zones(page_id)
         else:
             accepted_zones = []
@@ -815,9 +856,11 @@ def create_parser():
     )
 
     parser.add_argument(
-        "--filter_printed",
-        action="store_true",
-        help="Filter lines annotated as printed",
+        "--style",
+        type=lambda x: Style(x.lower()),  # ValueError -> clean argparse error
+        default=None,
+        help=f"Filter line images by style class. 'other' corresponds to line elements that "
+        f"have neither handwritten or typewritten class : {[s.name for s in Style]}",
     )
 
     parser.add_argument(
@@ -857,6 +900,11 @@ def main():
     if not args.dataset_name and not args.split_only and not args.format == "kraken":
         parser.error("--dataset_name must be specified (unless --split-only)")
 
+    if args.style and args.accepted_classes:
+        if any(c in args.accepted_classes for c in STYLE_CLASSES):
+            parser.error(f"--style class values ({STYLE_CLASSES}) shouldn't be in the accepted_classes list "
+                         "if both --style and --accepted_classes are used together.")
+
     logger.info(f"ARGS {args} \n")
 
     api_client = create_api_client(args.cache_dir)
@@ -869,7 +917,7 @@ def main():
             grayscale=args.grayscale,
             extraction=args.extraction_mode,
             accepted_classes=args.accepted_classes,
-            filter_printed=args.filter_printed,
+            style=args.style,
             skip_vertical_lines=args.skip_vertical_lines,
             transcription_type=args.transcription_type,
             accepted_worker_version_ids=args.accepted_worker_version_ids,
-- 
GitLab