From 338fda479cb51f349c56f0d80accaac08b2ca0ed Mon Sep 17 00:00:00 2001
From: Martin <maarand@teklia.com>
Date: Mon, 14 Sep 2020 17:31:43 +0200
Subject: [PATCH] support new transcriptions
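
Line transcriptions may no longer carry their own zone, so the polygon
is taken from the transcription zone when present and from the parent
element's zone otherwise. ListTranscriptions is now paginated with
recursive=True, the full page image URL is read from the page element
instead of the transcription zone, and run_pages receives full page
objects rather than plain ids. The --pages argument is disabled for
now because run_pages needs the page image data.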

---
 kaldi_data_generator.py | 40 +++++++++++++++++++++++++---------------
 1 file changed, 25 insertions(+), 15 deletions(-)

diff --git a/kaldi_data_generator.py b/kaldi_data_generator.py
index 736f715..6440910 100644
--- a/kaldi_data_generator.py
+++ b/kaldi_data_generator.py
@@ -94,7 +94,7 @@ class KaldiDataGenerator:
             img = cv2.imread(out_full_img_path)
         return img
 
-    def extract_lines(self, page_id: str):
+    def extract_lines(self, page_id: str, image_data: dict):
         count = 0
         lines = []
         try:
@@ -114,7 +114,7 @@ class KaldiDataGenerator:
                                     accepted_zones.append(elt['zone']['id'])
                     logger.info('Number of accepted zone for page {} : {}'.format(page_id,len(accepted_zones)))
 
-            for res in api_client.paginate('ListTranscriptions', id=page_id, type='line'):
+            for res in api_client.paginate('ListTranscriptions', id=page_id, type='line', recursive=True):
                 if self.should_filter_by_slug and res['source']['slug'] not in self.accepted_slugs:
                     continue
 
@@ -124,7 +124,15 @@ class KaldiDataGenerator:
                 text = res['text']
                 if not text or not text.strip():
                     continue
-                polygon = np.asarray(res['zone']['polygon']).clip(0)
+
+                if res['zone']:
+                    polygon = res['zone']['polygon']
+                elif res['element']['zone']:
+                    polygon = res['element']['zone']['polygon']
+                else:
+                    raise ValueError(f"Transcription has no zone polygon: {res}")
+
+                polygon = np.asarray(polygon).clip(0)
                 [x, y, w, h] = cv2.boundingRect(polygon)
                 lines.append(((x, y, w, h), polygon, text))
                 count += 1
@@ -136,9 +144,9 @@ class KaldiDataGenerator:
             logger.info(f"Page {page_id} skipped, because it has no lines")
             return
 
-        full_image_url = res['zone']['image']['s3_url']
+        full_image_url = image_data['s3_url']
         if full_image_url is None:
-            full_image_url = res['zone']['image']['url'] + '/full/full/0/default.jpg'
+            full_image_url = image_data['url'] + '/full/full/0/default.jpg'
 
         img = self.get_image(full_image_url, page_id=page_id)
 
@@ -175,16 +183,18 @@ class KaldiDataGenerator:
         dst2 = bg + dst
         return dst2
 
-    def run_pages(self, page_ids: list):
-        for page_id in tqdm.tqdm(page_ids):
+    def run_pages(self, pages: list):
+        for page in tqdm.tqdm(pages):
+            page_id = page['id']
+            image_data = page['zone']['image']
             logger.debug(f"Page {page_id}")
-            self.extract_lines(page_id)
+            self.extract_lines(page_id, image_data)
 
     def run_volumes(self, volume_ids: list):
         for volume_id in tqdm.tqdm(volume_ids):
             logger.info(f"Volume {volume_id}")
-            page_ids = [page['id'] for page in api_client.paginate('ListElementChildren', id=volume_id)]
-            self.run_pages(page_ids)
+            pages = list(api_client.paginate('ListElementChildren', id=volume_id))
+            self.run_pages(pages)
 
     def run_folders(self, element_ids: list, volume_type: str):
         for elem_id in tqdm.tqdm(element_ids):
@@ -266,7 +276,7 @@ def create_parser():
     parser.add_argument('--train_ratio', type=float, default=0.8,
                         help='Ratio of pages to be used in train (between 0 and 1)')
     parser.add_argument('--test_ratio', type=float, default=0.1,
-                        help='Ratio of pages to be used in train (between 0 and 1 - train_ratio)')
+                        help='Ratio of pages to be used in test (between 0 and 1 - train_ratio)')
     parser.add_argument('-e', '--extraction_mode', type=lambda x: Extraction[x], default=Extraction.boundingRect,
                         help=f'Mode for extracting the line images: {[e.name for e in Extraction]}')
 
@@ -284,8 +294,8 @@ def create_parser():
                              'Elements of `volume_type` will be searched recursively in these folders')
     parser.add_argument('--volumes', nargs='*',
                         help='List of volume ids to be used, separated by spaces')
-    parser.add_argument('--pages', nargs='*',
-                        help='List of page ids to be used, separated by spaces')
+    # parser.add_argument('--pages', nargs='*',
+    #                     help='List of page ids to be used, separated by spaces')
     parser.add_argument('-v', '--volume_type', type=str, default='volume',
                         help='Volumes (1 level above page) may have a different name on corpora')
 
@@ -317,8 +327,8 @@ def main():
                                                split_train_ratio=args.train_ratio,
                                                split_test_ratio=args.test_ratio)
     # extract all the lines and transcriptions
-    if args.pages:
-        kaldi_data_generator.run_pages(args.pages)
+    # if args.pages:
+    #     kaldi_data_generator.run_pages(args.pages)
     if args.volumes:
         kaldi_data_generator.run_volumes(args.volumes)
     if args.folders:
-- 
GitLab