Skip to content
Snippets Groups Projects

Implement extraction command

Merged Yoann Schneider requested to merge implement-extraction-command into main
9 files
+ 8
1753
Compare changes
  • Side-by-side
  • Inline
Files
9
import os
import shutil
import tarfile
import pickle
import re
from PIL import Image
import numpy as np
class DatasetFormatter:
"""
Global pipeline/functions for dataset formatting
"""
def __init__(self, dataset_name, level, extra_name="", set_names=["train", "valid", "test"]):
self.dataset_name = dataset_name
self.level = level
self.set_names = set_names
self.target_fold_path = os.path.join(
"Datasets", "formatted", "{}_{}{}".format(dataset_name, level, extra_name))
self.map_datasets_files = dict()
self.extract_with_dirname = False
def format(self):
self.init_format()
self.map_datasets_files[self.dataset_name][self.level]["format_function"]()
self.end_format()
def init_format(self):
"""
Load and extracts needed files
"""
os.makedirs(self.target_fold_path, exist_ok=True)
for set_name in self.set_names:
os.makedirs(os.path.join(self.target_fold_path, set_name), exist_ok=True)
class OCRDatasetFormatter(DatasetFormatter):
"""
Specific pipeline/functions for OCR/HTR dataset formatting
"""
def __init__(self, source_dataset, level, extra_name="", set_names=["train", "valid", "test"]):
super(OCRDatasetFormatter, self).__init__(source_dataset, level, extra_name, set_names)
self.charset = set()
self.gt = dict()
for set_name in set_names:
self.gt[set_name] = dict()
def format_text_label(self, label):
"""
Remove extra space or line break characters
"""
temp = re.sub("(\n)+", '\n', label)
return re.sub("( )+", ' ', temp).strip(" \n")
def load_resize_save(self, source_path, target_path):
"""
Load image, apply resolution modification and save it
"""
shutil.copyfile(source_path, target_path)
def resize(self, img, source_dpi, target_dpi):
"""
Apply resolution modification to image
"""
if source_dpi == target_dpi:
return img
if isinstance(img, np.ndarray):
h, w = img.shape[:2]
img = Image.fromarray(img)
else:
w, h = img.size
ratio = target_dpi / source_dpi
img = img.resize((int(w*ratio), int(h*ratio)), Image.BILINEAR)
return np.array(img)
def end_format(self):
"""
Save label and charset files
"""
with open(os.path.join(self.target_fold_path, "labels.pkl"), "wb") as f:
pickle.dump({
"ground_truth": self.gt,
"charset": sorted(list(self.charset)),
}, f)
with open(os.path.join(self.target_fold_path, "charset.pkl"), "wb") as f:
pickle.dump(sorted(list(self.charset)), f)
Loading