From 4a76d32def89002963dfe55f78044c82788d8189 Mon Sep 17 00:00:00 2001
From: Martin <maarand@teklia.com>
Date: Mon, 4 Nov 2019 18:25:18 +0100
Subject: [PATCH] refactor, use class

---
 kaldi_data_generator.py | 258 +++++++++++++++++++++-------------------
 1 file changed, 137 insertions(+), 121 deletions(-)

diff --git a/kaldi_data_generator.py b/kaldi_data_generator.py
index 7e044c3..b7c6126 100644
--- a/kaldi_data_generator.py
+++ b/kaldi_data_generator.py
@@ -1,24 +1,21 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
 
-import glob
 import os
-from pathlib import Path
-
-import requests
+import random
 from io import BytesIO
-from PIL import Image
+from pathlib import Path
 
 import cv2
-
 import numpy as np
-
-import random
-
+import requests
+from PIL import Image
 from apistar.exceptions import ErrorResponse
 from arkindex import ArkindexClient, options_from_env
+
 api_client = ArkindexClient(**options_from_env())
 
+
 def download_image(url):
     '''
     Download an image and open it with Pillow
@@ -33,122 +30,141 @@ def download_image(url):
     # Preprocess the image and prepare it for classification
     image = Image.open(BytesIO(resp.content))
     print('Downloaded image {} - size={}x{}'.format(url,
-                                                          image.size[0],
-                                                          image.size[1]))
+                                                    image.size[0],
+                                                    image.size[1]))
 
     return image
 
+
 def write_file(file_name, content):
     with open(file_name, 'w') as f:
         f.write(content)
 
-def get_image(image_url, grayscale, out_dir):
-    out_full_img_dir = os.path.join(out_dir, 'full', page_id)
-    os.makedirs(out_full_img_dir, exist_ok=True)
-    out_full_img_path = os.path.join(out_full_img_dir, 'full.jpg')
-    if grayscale:
-        download_image(image_url).convert('L').save(
-                        out_full_img_path, format='jpeg')
-        img = cv2.imread(out_full_img_path, cv2.IMREAD_GRAYSCALE)
-    else:
-        download_image(image_url).save(
-                        out_full_img_path, format='jpeg')
-        img = cv2.imread(out_full_img_path)
-    return img
-    
-def extract_lines(page_id, grayscale=True, out_dir='/tmp'):
-    count = 0
-    line_bounding_rects = []
-    line_polygons = []
-    line_transcriptions = []
-    try:
-        for res in api_client.paginate('ListTranscriptions', id=page_id, type='line'):
-            text = res['text']
-            if not text or not text.strip():
-                continue
-            line_transcriptions.append(text)
-            polygon = res['zone']['polygon']
-            line_polygons.append(polygon)
-            [x, y, w, h] = cv2.boundingRect(np.asarray(polygon))
-            line_bounding_rects.append([x, y, w, h])
-            count += 1
-    except ErrorResponse as e:
-        print("ListTranscriptions failed", e.status_code, e.title, e.content, page_id)
-        raise e
-
-    full_image_url = res['zone']['image']['s3_url']
-    
-    img = get_image(full_image_url, grayscale=grayscale, out_dir=out_dir)
-    
-    out_line_img_dir = os.path.join(out_dir, 'Lines', page_id)
-    os.makedirs(out_line_img_dir, exist_ok=True)
-    for i, [x, y, w, h] in enumerate(line_bounding_rects):
-        croped = img[y:y + h, x:x + w].copy()
-#        cv2.imwrite(f'{out_line_img_dir}/{i}.jpg', croped)
-        cv2.imwrite(f'{out_line_img_dir}_{i}.jpg', croped)
-    
-    out_line_text_dir = os.path.join(out_dir, 'Transcriptions', page_id)
-    os.makedirs(out_line_text_dir, exist_ok=True)
-    for i, text in enumerate(line_transcriptions):
-        write_file(f"{out_line_text_dir}_{i}.txt", text)
-#        write_file(f"{out_line_text_dir}/{i}.txt", text)
-
-
-split_train_ratio = 0.8
-split_test_ratio = 0.1
-split_val_ratio = 1 - split_train_ratio - split_test_ratio
-
-
-def page_level_split(line_ids):
-#    page_ids = list({'_'.join(line_id.split('_')[:-1]) for line_id in line_ids})
-    page_ids = list({line_id for line_id in line_ids})
-    random.shuffle(page_ids)
-    page_count = len(page_ids)
-    
-    train_page_ids = page_ids[:round(page_count * split_train_ratio)]
-    page_ids = page_ids[round(page_count * split_train_ratio):]
-    
-    test_page_ids = page_ids[:round(page_count * split_test_ratio)]
-    page_ids = page_ids[round(page_count * split_test_ratio):]
-    
-    val_page_ids = page_ids
-    
-    page_dict = {page_id: TRAIN for page_id in train_page_ids}
-    page_dict.update({page_id: TEST for page_id in test_page_ids})
-    page_dict.update({page_id: VAL for page_id in val_page_ids})
-    return (train_page_ids, val_page_ids, test_page_ids), page_dict
-
-TRAIN,TEST,VAL = 0,1,2
-out_file_dict = {0 : 'Train', 1 : 'Test', 2 : 'Validation'}
-
-def create_partitions(line_ids, out_dir):
-    (train_page_ids, val_page_ids, test_page_ids), page_dict = page_level_split(line_ids)
-    datasets = [[] for i in range(3)]
-    for line_id in line_ids:
-        page_id = line_id
-        split_id = page_dict[page_id]
-        datasets[split_id].append(line_id)
-    
-    partitions_dir = os.path.join(out_dir, 'Partitions')
-    os.makedirs(partitions_dir, exist_ok=True)
-    for i, dataset in enumerate(datasets):
-        file_name = f"{partitions_dir}/{out_file_dict[i]}Lines.lst"
-        with open(file_name, 'w') as f:
-            f.write('\n'.join(dataset) + '\n')
-
-
-out_dir_base = '/tmp/foo2'
-
-#page_id = 'bf23cc96-f6b2-4182-923e-6c163db37eba'
-page_ids = ['bf23cc96-f6b2-4182-923e-6c163db37eba',
-            '7c51e648-370e-43b7-9340-3b1a17c13828',
-            '56521074-59f4-4173-bfc1-4b1384ff8139',]
-
-for page_id in page_ids:
-    extract_lines(page_id, out_dir=out_dir_base)
-
-
-lines_path = Path(f'{out_dir_base}/Lines')
-line_ids = [str(file.relative_to(lines_path).with_suffix('')) for file in lines_path.glob('**/*.jpg')]
-
-create_partitions(line_ids, out_dir=out_dir_base)
+
+TRAIN, TEST, VAL = 0, 1, 2
+out_file_dict = {0: 'Train', 1: 'Test', 2: 'Validation'}
+
+
+class KaldiDataGenerator:
+
+    def __init__(self, dataset_name='foo', out_dir_base='/tmp/kaldi_data', split_train_ratio=0.8, split_test_ratio=0.1,
+                 grayscale=True):
+        self.out_dir_base = out_dir_base
+        self.dataset_name = dataset_name
+        self.split_train_ratio = split_train_ratio
+        self.split_test_ratio = split_test_ratio
+        self.split_val_ratio = 1 - self.split_train_ratio - self.split_test_ratio
+        self.grayscale = grayscale
+
+    def get_image(self, image_url, page_id):
+        out_full_img_dir = os.path.join(self.out_dir_base, 'full', page_id)
+        os.makedirs(out_full_img_dir, exist_ok=True)
+        out_full_img_path = os.path.join(out_full_img_dir, 'full.jpg')
+        if self.grayscale:
+            download_image(image_url).convert('L').save(
+                out_full_img_path, format='jpeg')
+            img = cv2.imread(out_full_img_path, cv2.IMREAD_GRAYSCALE)
+        else:
+            download_image(image_url).save(
+                out_full_img_path, format='jpeg')
+            img = cv2.imread(out_full_img_path)
+        return img
+
+    def extract_lines(self, page_id):
+        count = 0
+        line_bounding_rects = []
+        line_polygons = []
+        line_transcriptions = []
+        try:
+            for res in api_client.paginate('ListTranscriptions', id=page_id, type='line'):
+                text = res['text']
+                if not text or not text.strip():
+                    continue
+                line_transcriptions.append(text)
+                polygon = res['zone']['polygon']
+                line_polygons.append(polygon)
+                [x, y, w, h] = cv2.boundingRect(np.asarray(polygon))
+                line_bounding_rects.append([x, y, w, h])
+                count += 1
+        except ErrorResponse as e:
+            print("ListTranscriptions failed", e.status_code, e.title, e.content, page_id)
+            raise e
+        print("C", count)
+        full_image_url = res['zone']['image']['s3_url']
+
+        img = self.get_image(full_image_url, page_id=page_id)
+
+        out_line_img_dir = os.path.join(self.out_dir_base, 'Lines', self.dataset_name, page_id)
+        os.makedirs(out_line_img_dir, exist_ok=True)
+        for i, [x, y, w, h] in enumerate(line_bounding_rects):
+            cropped = img[y:y + h, x:x + w].copy()
+            cv2.imwrite(f'{out_line_img_dir}_{i}.jpg', cropped)
+
+        out_line_text_dir = os.path.join(self.out_dir_base, 'Transcriptions', self.dataset_name, page_id)
+        os.makedirs(out_line_text_dir, exist_ok=True)
+        for i, text in enumerate(line_transcriptions):
+            write_file(f"{out_line_text_dir}_{i}.txt", text)
+
+    def page_level_split(self, line_ids):
+        page_ids = list({'_'.join(line_id.split('_')[:-1]) for line_id in line_ids})
+        # page_ids = list({line_id for line_id in line_ids})
+        random.shuffle(page_ids)
+        page_count = len(page_ids)
+
+        train_page_ids = page_ids[:round(page_count * self.split_train_ratio)]
+        page_ids = page_ids[round(page_count * self.split_train_ratio):]
+
+        test_page_ids = page_ids[:round(page_count * self.split_test_ratio)]
+        page_ids = page_ids[round(page_count * self.split_test_ratio):]
+
+        val_page_ids = page_ids
+
+        page_dict = {page_id: TRAIN for page_id in train_page_ids}
+        page_dict.update({page_id: TEST for page_id in test_page_ids})
+        page_dict.update({page_id: VAL for page_id in val_page_ids})
+        return page_dict
+
+    def create_partitions(self):
+        lines_path = Path(f'{self.out_dir_base}/Lines')
+        line_ids = [str(file.relative_to(lines_path).with_suffix('')) for file in lines_path.glob('**/*.jpg')]
+
+        page_dict = self.page_level_split(line_ids)
+        datasets = [[] for _ in range(3)]
+        for line_id in line_ids:
+            page_id = '_'.join(line_id.split('_')[:-1])
+            split_id = page_dict[page_id]
+            datasets[split_id].append(line_id)
+
+        partitions_dir = os.path.join(self.out_dir_base, 'Partitions')
+        os.makedirs(partitions_dir, exist_ok=True)
+        for i, dataset in enumerate(datasets):
+            file_name = f"{partitions_dir}/{out_file_dict[i]}Lines.lst"
+            write_file(file_name, '\n'.join(dataset) + '\n')
+
+    def run_pages(self, page_ids):
+        for page_id in page_ids:
+            print("P", page_id)
+            self.extract_lines(page_id)
+
+    def run_volumes(self, volume_ids):
+        for volume_id in volume_ids:
+            print("V", volume_id)
+            page_ids = [page['id'] for page in api_client.paginate('ListElementChildren', id=volume_id)]
+            self.run_pages(page_ids)
+
+
+example_page_ids = [
+    'bf23cc96-f6b2-4182-923e-6c163db37eba',
+    '7c51e648-370e-43b7-9340-3b1a17c13828',
+    '56521074-59f4-4173-bfc1-4b1384ff8139',
+]
+
+example_volume_ids = [
+    '8f4005e9-1921-47b0-be7b-e27c7fd29486',
+]
+
+kaldi_data_generator = KaldiDataGenerator()
+
+# kaldi_data_generator.run_page(example_page_ids)
+kaldi_data_generator.run_volumes(example_volume_ids)
+kaldi_data_generator.create_partitions()
-- 
GitLab