diff --git a/dan/manager/metrics.py b/dan/manager/metrics.py
index c9e3a64203008e833366274c7df8335cd9883e08..36d3a29f489dd30533e46cc4b732b3cb36c81d33 100644
--- a/dan/manager/metrics.py
+++ b/dan/manager/metrics.py
@@ -127,7 +127,13 @@ class MetricManager:
                 )
                 if output:
                     display_values["nb_words"] = np.sum(self.epoch_metrics["nb_words"])
-            elif metric_name in ["loss", "loss_ctc", "loss_ce", "syn_max_lines"]:
+            elif metric_name in [
+                "loss",
+                "loss_ctc",
+                "loss_ce",
+                "syn_max_lines",
+                "syn_prob_lines",
+            ]:
                 value = np.average(
                     self.epoch_metrics[metric_name],
                     weights=np.array(self.epoch_metrics["nb_samples"]),
@@ -182,6 +188,7 @@ class MetricManager:
             "loss_ce",
             "loss",
             "syn_max_lines",
+            "syn_prob_lines",
         ]:
             metrics[metric_name] = [
                 values[metric_name],
diff --git a/dan/manager/ocr.py b/dan/manager/ocr.py
index 199cf5dcced64738942c65fc4c14bcc49cbd0a66..775fce6341a7e808bc2f4fef93f0897a569125b6 100644
--- a/dan/manager/ocr.py
+++ b/dan/manager/ocr.py
@@ -9,10 +9,12 @@ import torch
 from fontTools.ttLib import TTFont
 from PIL import Image, ImageDraw, ImageFont
 
+from dan import logger
 from dan.manager.dataset import DatasetManager, GenericDataset, apply_preprocessing
 from dan.ocr.utils import LM_str_to_ind
 from dan.utils import (
     pad_image,
+    pad_image_width_random,
     pad_image_width_right,
     pad_images,
     pad_sequences_1D,
@@ -37,17 +39,10 @@ class OCRDatasetManager(DatasetManager):
         if (
             "synthetic_data" in self.params["config"]
             and self.params["config"]["synthetic_data"]
-            and "config" in self.params["config"]["synthetic_data"]
         ):
-            self.char_only_set = self.charset.copy()
-            for token in [
-                "\n",
-            ]:
-                if token in self.char_only_set:
-                    self.char_only_set.remove(token)
-            self.params["config"]["synthetic_data"]["config"][
-                "valid_fonts"
-            ] = get_valid_fonts(self.char_only_set)
+            self.synthetic_data = self.params["config"]["synthetic_data"]
+            if "config" in self.synthetic_data:
+                self.synthetic_data["config"]["valid_fonts"] = self.get_valid_fonts()
 
         if "new_tokens" in params:
             self.charset = sorted(
@@ -108,6 +103,34 @@ class OCRDatasetManager(DatasetManager):
             [s["img"].shape[1] for s in self.train_dataset.samples]
         )
 
+    def get_valid_fonts(self):
+        """
+        Select fonts that are compatible with the alphabet
+        """
+        fonts_dir = self.synthetic_data["font_path"]
+        alphabet = self.charset.copy()
+        special_chars = ["\n"]
+        alphabet = [char for char in alphabet if char not in special_chars]
+        valid_fonts = list()
+        for fold_detail in os.walk(fonts_dir):
+            if fold_detail[2]:
+                for font_name in fold_detail[2]:
+                    if ".ttf" not in font_name:
+                        continue
+                    font_path = os.path.join(fold_detail[0], font_name)
+                    to_add = True
+                    if alphabet is not None:
+                        for char in alphabet:
+                            if not char_in_font(char, font_path):
+                                to_add = False
+                                break
+                        if to_add:
+                            valid_fonts.append(font_path)
+                    else:
+                        valid_fonts.append(font_path)
+        logger.info(f"Found {len(valid_fonts)} fonts.")
+        return valid_fonts
+
 
 class OCRDataset(GenericDataset):
     """
@@ -123,6 +146,13 @@ class OCRDataset(GenericDataset):
         )
         self.collate_function = OCRCollateFunction
         self.synthetic_id = 0
+        if (
+            "synthetic_data" in self.params["config"]
+            and self.params["config"]["synthetic_data"]
+        ):
+            self.synthetic_data = self.params["config"]["synthetic_data"]
+        else:
+            self.synthetic_data = None
 
     def __getitem__(self, idx):
         sample = copy.deepcopy(self.samples[idx])
@@ -281,67 +311,75 @@ class OCRDataset(GenericDataset):
         return sample
 
     def generate_synthetic_data(self, sample):
-        config = self.params["config"]["synthetic_data"]
-
-        if not (config["init_proba"] == config["end_proba"] == 1):
-            nb_samples = self.training_info["step"] * self.params["batch_size"]
-            if config["start_scheduler_at_max_line"]:
-                max_step = config["num_steps_proba"]
-                current_step = max(
-                    0,
-                    min(
-                        nb_samples
-                        - config["curr_step"]
-                        * (config["max_nb_lines"] - config["min_nb_lines"]),
-                        max_step,
-                    ),
-                )
-                proba = (
-                    config["init_proba"]
-                    if self.get_syn_max_lines() < config["max_nb_lines"]
-                    else config["proba_scheduler_function"](
-                        config["init_proba"],
-                        config["end_proba"],
-                        current_step,
-                        max_step,
-                    )
-                )
-            else:
-                proba = config["proba_scheduler_function"](
-                    config["init_proba"],
-                    config["end_proba"],
-                    min(nb_samples, config["num_steps_proba"]),
-                    config["num_steps_proba"],
-                )
-            if rand() > proba:
-                return sample
-
-        if "mode" in config and config["mode"] == "line_hw_to_printed":
+        proba = self.get_syn_proba_lines()
+        if rand() > proba:
+            return sample
+        if (
+            "mode" in self.synthetic_data
+            and self.synthetic_data["mode"] == "line_hw_to_printed"
+        ):
             sample["img"] = self.generate_typed_text_line_image(sample["label"])
             return sample
-
         return self.generate_synthetic_page_sample()
 
     def get_syn_max_lines(self):
-        config = self.params["config"]["synthetic_data"]
-        if config["curriculum"]:
+        if self.synthetic_data["curriculum"]:
             nb_samples = self.training_info["step"] * self.params["batch_size"]
             max_nb_lines = min(
-                config["max_nb_lines"],
-                (nb_samples - config["curr_start"]) // config["curr_step"] + 1,
+                self.synthetic_data["max_nb_lines"],
+                (nb_samples - self.synthetic_data["curr_start"])
+                // self.synthetic_data["curr_step"]
+                + 1,
+            )
+            return max(self.synthetic_data["min_nb_lines"], max_nb_lines)
+        return self.synthetic_data["max_nb_lines"]
+
+    def get_syn_proba_lines(self):
+        if self.synthetic_data["init_proba"] == self.synthetic_data["end_proba"]:
+            return self.synthetic_data["init_proba"]
+        nb_samples = self.training_info["step"] * self.params["batch_size"]
+        if self.synthetic_data["start_scheduler_at_max_line"]:
+            max_step = self.synthetic_data["num_steps_proba"]
+            current_step = max(
+                0,
+                min(
+                    nb_samples
+                    - self.synthetic_data["curr_step"]
+                    * (
+                        self.synthetic_data["max_nb_lines"]
+                        - self.synthetic_data["min_nb_lines"]
+                    ),
+                    max_step,
+                ),
+            )
+            proba = (
+                self.synthetic_data["init_proba"]
+                if self.get_syn_max_lines() < self.synthetic_data["max_nb_lines"]
+                else self.synthetic_data["proba_scheduler_function"](
+                    self.synthetic_data["init_proba"],
+                    self.synthetic_data["end_proba"],
+                    current_step,
+                    max_step,
+                )
+            )
+        else:
+            proba = self.synthetic_data["proba_scheduler_function"](
+                self.synthetic_data["init_proba"],
+                self.synthetic_data["end_proba"],
+                min(nb_samples, self.synthetic_data["num_steps_proba"]),
+                self.synthetic_data["num_steps_proba"],
             )
-            return max(config["min_nb_lines"], max_nb_lines)
-        return config["max_nb_lines"]
+        return proba
 
     def generate_synthetic_page_sample(self):
-        config = self.params["config"]["synthetic_data"]
         max_nb_lines_per_page = self.get_syn_max_lines()
         crop = (
-            config["crop_curriculum"] and max_nb_lines_per_page < config["max_nb_lines"]
+            self.synthetic_data["crop_curriculum"]
+            and max_nb_lines_per_page < self.synthetic_data["max_nb_lines"]
        )
         sample = {"name": "synthetic_data_{}".format(self.synthetic_id), "path": None}
         self.synthetic_id += 1
-        nb_pages = 2 if "double" in config["dataset_level"] else 1
+        nb_pages = 2 if "double" in self.synthetic_data["dataset_level"] else 1
         background_sample = copy.deepcopy(self.samples[randint(0, len(self))])
         pages = list()
         backgrounds = list()
@@ -350,7 +388,7 @@ class OCRDataset(GenericDataset):
         page_width = w // 2 if nb_pages == 2 else w
         for i in range(nb_pages):
             nb_lines_per_page = randint(
-                config["min_nb_lines"], max_nb_lines_per_page + 1
+                self.synthetic_data["min_nb_lines"], max_nb_lines_per_page + 1
             )
             background = (
                 np.ones((h, page_width, c), dtype=background_sample["img"].dtype) * 255
@@ -386,7 +424,19 @@ class OCRDataset(GenericDataset):
                     )
                 )
             else:
-                raise NotImplementedError
+                # Get a page-level transcription and split it by lines
+                texts = self.samples[randint(0, len(self))]["label"].split("\n")
+                # Select some lines to be generated
+                n_lines = min(len(texts), nb_lines_per_page)
+                start = randint(0, len(texts) - n_lines + 1)
+                texts = texts[start : start + n_lines]
+                # Generate the synthetic document (of n_lines)
+                pages.append(
+                    self.generate_typed_text_paragraph_image(
+                        texts=texts,
+                        same_font_size=True,
+                    )
+                )
 
         if nb_pages == 1:
             sample["img"] = pages[0][0]
@@ -419,9 +469,79 @@ class OCRDataset(GenericDataset):
         return sample
 
     def generate_typed_text_line_image(self, text):
-        return generate_typed_text_line_image(
-            text, self.params["config"]["synthetic_data"]["config"]
-        )
+        return generate_typed_text_line_image(text, self.synthetic_data["config"])
+
+    def generate_typed_text_paragraph_image(
+        self, texts, padding_value=255, max_pad_left_ratio=0.1, same_font_size=False
+    ):
+        """
+        Generate a synthetic paragraph from a list of texts where each line is generated with a different font.
+        """
+        if same_font_size:
+            images = list()
+            txt_color = self.synthetic_data["config"]["text_color_default"]
+            bg_color = self.synthetic_data["config"]["background_color_default"]
+            font_size = randint(
+                self.synthetic_data["config"]["font_size_min"],
+                self.synthetic_data["config"]["font_size_max"] + 1,
+            )
+            for text in texts:
+                font_path = self.synthetic_data["config"]["valid_fonts"][
+                    randint(0, len(self.synthetic_data["config"]["valid_fonts"]))
+                ]
+                fnt = ImageFont.truetype(font_path, font_size)
+                text_width, text_height = fnt.getsize(text)
+                padding_top = get_random_padding(
+                    self.synthetic_data["config"]["padding_top_ratio_min"],
+                    self.synthetic_data["config"]["padding_top_ratio_max"],
+                    text_height,
+                )
+                padding_bottom = get_random_padding(
+                    self.synthetic_data["config"]["padding_bottom_ratio_min"],
+                    self.synthetic_data["config"]["padding_bottom_ratio_max"],
+                    text_height,
+                )
+                padding_left = get_random_padding(
+                    self.synthetic_data["config"]["padding_left_ratio_min"],
+                    self.synthetic_data["config"]["padding_left_ratio_max"],
+                    text_width,
+                )
+                padding_right = get_random_padding(
+                    self.synthetic_data["config"]["padding_right_ratio_min"],
+                    self.synthetic_data["config"]["padding_right_ratio_max"],
+                    text_width,
+                )
+                padding = [padding_top, padding_bottom, padding_left, padding_right]
+                images.append(
+                    generate_typed_text_line_image_from_params(
+                        text,
+                        fnt,
+                        bg_color,
+                        txt_color,
+                        self.synthetic_data["config"]["color_mode"],
+                        padding,
+                    )
+                )
+        else:
+            images = [self.generate_typed_text_line_image(t) for t in texts]
+
+        max_width = max([img.shape[1] for img in images])
+        padded_images = [
+            pad_image_width_random(
+                img,
+                max_width,
+                padding_value=padding_value,
+                max_pad_left_ratio=max_pad_left_ratio,
+            )
+            for img in images
+        ]
+        label = {
+            "sem": "\n".join(texts),
+            "begin": "\n".join(texts),
+            "raw": "\n".join(texts),
+        }
+        # image, label, n_col
+        return [np.concatenate(padded_images, axis=0), label, 1]
 
 
 class OCRCollateFunction:
@@ -555,6 +675,13 @@ class OCRCollateFunction:
         return formatted_batch_data
 
 
+def get_random_padding(min_ratio, max_ratio, text_size):
+    """
+    Compute random padding value as a ratio of text width or height
+    """
+    return int(rand_uniform(min_ratio, max_ratio) * text_size)
+
+
 def generate_typed_text_line_image(
     text, config, bg_color=(255, 255, 255), txt_color=(0, 0, 0)
 ):
@@ -570,25 +697,25 @@ def generate_typed_text_line_image(
     fnt = ImageFont.truetype(font_path, font_size)
     text_width, text_height = fnt.getsize(text)
 
-    padding_top = int(
-        rand_uniform(config["padding_top_ratio_min"], config["padding_top_ratio_max"])
-        * text_height
+    padding_top = get_random_padding(
+        config["padding_top_ratio_min"],
+        config["padding_top_ratio_max"],
+        text_height,
     )
-    padding_bottom = int(
-        rand_uniform(
-            config["padding_bottom_ratio_min"], config["padding_bottom_ratio_max"]
-        )
-        * text_height
+    padding_bottom = get_random_padding(
+        config["padding_bottom_ratio_min"],
+        config["padding_bottom_ratio_max"],
+        text_height,
     )
-    padding_left = int(
-        rand_uniform(config["padding_left_ratio_min"], config["padding_left_ratio_max"])
-        * text_width
+    padding_left = get_random_padding(
+        config["padding_left_ratio_min"],
+        config["padding_left_ratio_max"],
+        text_width,
    )
-    padding_right = int(
-        rand_uniform(
-            config["padding_right_ratio_min"], config["padding_right_ratio_max"]
-        )
-        * text_width
+    padding_right = get_random_padding(
+        config["padding_right_ratio_min"],
+        config["padding_right_ratio_max"],
+        text_width,
     )
     padding = [padding_top, padding_bottom, padding_left, padding_right]
     return generate_typed_text_line_image_from_params(
@@ -616,24 +743,3 @@ def char_in_font(unicode_char, font_path):
         if ord(unicode_char) in cmap.cmap:
             return True
     return False
-
-
-def get_valid_fonts(alphabet=None):
-    valid_fonts = list()
-    for fold_detail in os.walk("../../../Fonts"):
-        if fold_detail[2]:
-            for font_name in fold_detail[2]:
-                if ".ttf" not in font_name:
-                    continue
-                font_path = os.path.join(fold_detail[0], font_name)
-                to_add = True
-                if alphabet is not None:
-                    for char in alphabet:
-                        if not char_in_font(char, font_path):
-                            to_add = False
-                            break
-                    if to_add:
-                        valid_fonts.append(font_path)
-                else:
-                    valid_fonts.append(font_path)
-    return valid_fonts
diff --git a/dan/manager/training.py b/dan/manager/training.py
index 682e3091735c01244e9249568675eb1ad86c53d0..41ba05cb60b6da7c326c71f4e98f139fbeaf05ec 100644
--- a/dan/manager/training.py
+++ b/dan/manager/training.py
@@ -1160,6 +1160,9 @@ class Manager(OCRManager):
             "syn_max_lines": self.dataset.train_dataset.get_syn_max_lines()
             if self.params["dataset_params"]["config"]["synthetic_data"]
             else 0,
+            "syn_prob_lines": self.dataset.train_dataset.get_syn_proba_lines()
+            if self.params["dataset_params"]["config"]["synthetic_data"]
+            else 0,
         }
 
         return values
diff --git a/dan/ocr/document/train.py b/dan/ocr/document/train.py
index a2b5e5684580a4ee8fb9a27acb7c305215812a65..b173ba99e3d7d90988c01b7f8d185a8fa9343846 100644
--- a/dan/ocr/document/train.py
+++ b/dan/ocr/document/train.py
@@ -10,7 +10,7 @@ from dan.decoder import GlobalHTADecoder
 from dan.manager.ocr import OCRDataset, OCRDatasetManager
 from dan.manager.training import Manager
 from dan.models import FCN_Encoder
-from dan.schedulers import exponential_dropout_scheduler
+from dan.schedulers import exponential_dropout_scheduler, linear_scheduler
 from dan.transforms import aug_config
 
 
@@ -32,7 +32,7 @@ def train_and_test(rank, params):
     model.params["training_params"]["load_epoch"] = "best"
     model.load_model()
 
-    metrics = ["cer", "wer", "time", "map_cer", "loer"]
+    metrics = ["cer", "wer", "time"]
"time"] for dataset_name in params["dataset_params"]["datasets"].keys(): for set_name in ["test", "val", "train"]: model.predict( @@ -46,9 +46,9 @@ def train_and_test(rank, params): def run(): - dataset_name = "simara" - dataset_level = "page" - dataset_variant = "_sem" + dataset_name = "esposalles" + dataset_level = "record" + dataset_variant = "" params = { "dataset_params": { @@ -90,40 +90,41 @@ def run(): }, ], "augmentation": aug_config(0.9, 0.1), - "synthetic_data": None, - # "synthetic_data": { - # "init_proba": 0.9, # begin proba to generate synthetic document - # "end_proba": 0.2, # end proba to generate synthetic document - # "num_steps_proba": 200000, # linearly decrease the percent of synthetic document from 90% to 20% through 200000 samples - # "proba_scheduler_function": linear_scheduler, # decrease proba rate linearly - # "start_scheduler_at_max_line": True, # start decreasing proba only after curriculum reach max number of lines - # "dataset_level": dataset_level, - # "curriculum": True, # use curriculum learning (slowly increase number of lines per synthetic samples) - # "crop_curriculum": True, # during curriculum learning, crop images under the last text line - # "curr_start": 0, # start curriculum at iteration - # "curr_step": 10000, # interval to increase the number of lines for curriculum learning - # "min_nb_lines": 1, # initial number of lines for curriculum learning - # "max_nb_lines": max_nb_lines[dataset_name], # maximum number of lines for curriculum learning - # "padding_value": 255, - # # config for synthetic line generation - # "config": { - # "background_color_default": (255, 255, 255), - # "background_color_eps": 15, - # "text_color_default": (0, 0, 0), - # "text_color_eps": 15, - # "font_size_min": 35, - # "font_size_max": 45, - # "color_mode": "RGB", - # "padding_left_ratio_min": 0.00, - # "padding_left_ratio_max": 0.05, - # "padding_right_ratio_min": 0.02, - # "padding_right_ratio_max": 0.2, - # "padding_top_ratio_min": 0.02, - # "padding_top_ratio_max": 0.1, - # "padding_bottom_ratio_min": 0.02, - # "padding_bottom_ratio_max": 0.1, - # }, - # } + # "synthetic_data": None, + "synthetic_data": { + "init_proba": 0.9, # begin proba to generate synthetic document + "end_proba": 0.2, # end proba to generate synthetic document + "num_steps_proba": 200000, # linearly decrease the percent of synthetic document from 90% to 20% through 200000 samples + "proba_scheduler_function": linear_scheduler, # decrease proba rate linearly + "start_scheduler_at_max_line": True, # start decreasing proba only after curriculum reach max number of lines + "dataset_level": dataset_level, + "curriculum": True, # use curriculum learning (slowly increase number of lines per synthetic samples) + "crop_curriculum": True, # during curriculum learning, crop images under the last text line + "curr_start": 0, # start curriculum at iteration + "curr_step": 10000, # interval to increase the number of lines for curriculum learning + "min_nb_lines": 1, # initial number of lines for curriculum learning + "max_nb_lines": 4, # maximum number of lines for curriculum learning + "padding_value": 255, + "font_path": "fonts/", + # config for synthetic line generation + "config": { + "background_color_default": (255, 255, 255), + "background_color_eps": 15, + "text_color_default": (0, 0, 0), + "text_color_eps": 15, + "font_size_min": 35, + "font_size_max": 45, + "color_mode": "RGB", + "padding_left_ratio_min": 0.00, + "padding_left_ratio_max": 0.05, + "padding_right_ratio_min": 0.02, + 
"padding_right_ratio_max": 0.2, + "padding_top_ratio_min": 0.02, + "padding_top_ratio_max": 0.1, + "padding_bottom_ratio_min": 0.02, + "padding_bottom_ratio_max": 0.1, + }, + }, }, }, "model_params": { @@ -134,8 +135,18 @@ def run(): # "transfer_learning": None, "transfer_learning": { # model_name: [state_dict_name, checkpoint_path, learnable, strict] - "encoder": ["encoder", "dan_rimes_page.pt", True, True], - "decoder": ["decoder", "dan_rimes_page.pt", True, False], + "encoder": [ + "encoder", + "pretrained_models/dan_rimes_page.pt", + True, + True, + ], + "decoder": [ + "decoder", + "pretrained_models/dan_rimes_page.pt", + True, + False, + ], }, "transfered_charset": True, # Transfer learning of the decision layer based on charset of the line HTR model "additional_tokens": 1, # for decision layer = [<eot>, ], only for transferred charset @@ -163,7 +174,7 @@ def run(): }, }, "training_params": { - "output_folder": "dan_simara_page", # folder name for checkpoint and results + "output_folder": "dan_esposalles_record", # folder name for checkpoint and results "max_nb_epochs": 50000, # maximum number of epochs before to stop "max_training_time": 3600 * 24 @@ -198,11 +209,11 @@ def run(): "cer", "wer", "syn_max_lines", + "syn_prob_lines", ], # Metrics name for training "eval_metrics": [ "cer", "wer", - "map_cer", ], # Metrics name for evaluation on validation set during training "force_cpu": False, # True for debug purposes "max_char_prediction": 1000, # max number of token prediction diff --git a/dan/utils.py b/dan/utils.py index 764bff9014689c50ae51126cab4d4447539525d5..8d6f4b4f4407b18af58ba09265b5d6639ccbf858 100644 --- a/dan/utils.py +++ b/dan/utils.py @@ -160,3 +160,22 @@ def pad_image_width_right(img, new_width, padding_value): pad_right = np.ones((h, pad_width, c), dtype=img.dtype) * padding_value img = np.concatenate([img, pad_right], axis=1) return img + + +def pad_image_width_random(img, new_width, padding_value, max_pad_left_ratio=1): + """ + Randomly pad img to left and right sides with padding value to reach new_width as width + """ + h, w, c = img.shape + pad_width = max((new_width - w), 0) + max_pad_left = int(max_pad_left_ratio * pad_width) + pad_left = ( + randint(0, min(pad_width, max_pad_left)) + if pad_width != 0 and max_pad_left > 0 + else 0 + ) + pad_right = pad_width - pad_left + pad_left = np.ones((h, pad_left, c), dtype=img.dtype) * padding_value + pad_right = np.ones((h, pad_right, c), dtype=img.dtype) * padding_value + img = np.concatenate([pad_left, img, pad_right], axis=1) + return img