diff --git a/dan/manager/ocr.py b/dan/manager/ocr.py index d097eae31ee6a890e1d2163410f8d578777e8dea..eb45816742759acaa4da4308690a3cf12e679555 100644 --- a/dan/manager/ocr.py +++ b/dan/manager/ocr.py @@ -9,6 +9,7 @@ import torch from fontTools.ttLib import TTFont from PIL import Image, ImageDraw, ImageFont +from dan import logger from dan.manager.dataset import DatasetManager, GenericDataset, apply_preprocessing from dan.ocr.utils import LM_str_to_ind from dan.utils import ( @@ -40,16 +41,9 @@ class OCRDatasetManager(DatasetManager): and self.params["config"]["synthetic_data"] and "config" in self.params["config"]["synthetic_data"] ): - self.char_only_set = self.charset.copy() - for token in [ - "\n", - ]: - if token in self.char_only_set: - self.char_only_set.remove(token) - self.font_path = self.params["config"]["synthetic_data"]["font_path"] self.params["config"]["synthetic_data"]["config"][ "valid_fonts" - ] = get_valid_fonts(self.font_path, self.char_only_set) + ] = self.get_valid_fonts() if "new_tokens" in params: self.charset = sorted( @@ -110,6 +104,34 @@ class OCRDatasetManager(DatasetManager): [s["img"].shape[1] for s in self.train_dataset.samples] ) + def get_valid_fonts(self): + """ + Select fonts that are compatible with the alphabet + """ + font_path = self.params["config"]["synthetic_data"]["font_path"] + alphabet = self.charset.copy() + special_chars = ["\n"] + alphabet = [char for char in alphabet if char not in special_chars] + valid_fonts = list() + for fold_detail in os.walk(font_path): + if fold_detail[2]: + for font_name in fold_detail[2]: + if ".ttf" not in font_name: + continue + font_path = os.path.join(fold_detail[0], font_name) + to_add = True + if alphabet is not None: + for char in alphabet: + if not char_in_font(char, font_path): + to_add = False + break + if to_add: + valid_fonts.append(font_path) + else: + valid_fonts.append(font_path) + logger.info(f"Found {len(valid_fonts)} fonts.") + return valid_fonts + class OCRDataset(GenericDataset): """ @@ -311,10 +333,8 @@ class OCRDataset(GenericDataset): if config["init_proba"] == config["end_proba"]: return config["init_proba"] - else: nb_samples = self.training_info["step"] * self.params["batch_size"] - if config["start_scheduler_at_max_line"]: max_step = config["num_steps_proba"] current_step = max( @@ -450,6 +470,9 @@ class OCRDataset(GenericDataset): def generate_typed_text_paragraph_image( self, texts, padding_value=255, max_pad_left_ratio=0.1, same_font_size=False ): + """ + Generate a synthetic paragraph from a list of texts where each line is generated with a different font. + """ config = self.params["config"]["synthetic_data"]["config"] if same_font_size: images = list() @@ -513,7 +536,8 @@ class OCRDataset(GenericDataset): "begin": "\n".join(texts), "raw": "\n".join(texts), } - return [np.concatenate(padded_images, axis=0), label, 1] # image, label, n_col + # image, label, n_col + return [np.concatenate(padded_images, axis=0), label, 1] class OCRCollateFunction: @@ -708,25 +732,3 @@ def char_in_font(unicode_char, font_path): if ord(unicode_char) in cmap.cmap: return True return False - - -def get_valid_fonts(font_path, alphabet=None): - valid_fonts = list() - for fold_detail in os.walk(font_path): - if fold_detail[2]: - for font_name in fold_detail[2]: - if ".ttf" not in font_name: - continue - font_path = os.path.join(fold_detail[0], font_name) - to_add = True - if alphabet is not None: - for char in alphabet: - if not char_in_font(char, font_path): - to_add = False - break - if to_add: - valid_fonts.append(font_path) - else: - valid_fonts.append(font_path) - print(f"Found {len(valid_fonts)} fonts.") - return valid_fonts