diff --git a/dan/manager/metrics.py b/dan/manager/metrics.py index c9e3a64203008e833366274c7df8335cd9883e08..36d3a29f489dd30533e46cc4b732b3cb36c81d33 100644 --- a/dan/manager/metrics.py +++ b/dan/manager/metrics.py @@ -127,7 +127,13 @@ class MetricManager: ) if output: display_values["nb_words"] = np.sum(self.epoch_metrics["nb_words"]) - elif metric_name in ["loss", "loss_ctc", "loss_ce", "syn_max_lines"]: + elif metric_name in [ + "loss", + "loss_ctc", + "loss_ce", + "syn_max_lines", + "syn_prob_lines", + ]: value = np.average( self.epoch_metrics[metric_name], weights=np.array(self.epoch_metrics["nb_samples"]), @@ -182,6 +188,7 @@ class MetricManager: "loss_ce", "loss", "syn_max_lines", + "syn_prob_lines", ]: metrics[metric_name] = [ values[metric_name], diff --git a/dan/manager/ocr.py b/dan/manager/ocr.py index 43d46894703d94e546196cf764ea582c266ccda0..d097eae31ee6a890e1d2163410f8d578777e8dea 100644 --- a/dan/manager/ocr.py +++ b/dan/manager/ocr.py @@ -284,9 +284,37 @@ class OCRDataset(GenericDataset): def generate_synthetic_data(self, sample): config = self.params["config"]["synthetic_data"] + proba = self.get_syn_proba_lines() - if not (config["init_proba"] == config["end_proba"] == 1): + if rand() > proba: + return sample + + if "mode" in config and config["mode"] == "line_hw_to_printed": + sample["img"] = self.generate_typed_text_line_image(sample["label"]) + return sample + + return self.generate_synthetic_page_sample() + + def get_syn_max_lines(self): + config = self.params["config"]["synthetic_data"] + if config["curriculum"]: nb_samples = self.training_info["step"] * self.params["batch_size"] + max_nb_lines = min( + config["max_nb_lines"], + (nb_samples - config["curr_start"]) // config["curr_step"] + 1, + ) + return max(config["min_nb_lines"], max_nb_lines) + return config["max_nb_lines"] + + def get_syn_proba_lines(self): + config = self.params["config"]["synthetic_data"] + + if config["init_proba"] == config["end_proba"]: + return config["init_proba"] + + else: + nb_samples = self.training_info["step"] * self.params["batch_size"] + if config["start_scheduler_at_max_line"]: max_step = config["num_steps_proba"] current_step = max( @@ -315,25 +343,7 @@ class OCRDataset(GenericDataset): min(nb_samples, config["num_steps_proba"]), config["num_steps_proba"], ) - if rand() > proba: - return sample - - if "mode" in config and config["mode"] == "line_hw_to_printed": - sample["img"] = self.generate_typed_text_line_image(sample["label"]) - return sample - - return self.generate_synthetic_page_sample() - - def get_syn_max_lines(self): - config = self.params["config"]["synthetic_data"] - if config["curriculum"]: - nb_samples = self.training_info["step"] * self.params["batch_size"] - max_nb_lines = min( - config["max_nb_lines"], - (nb_samples - config["curr_start"]) // config["curr_step"] + 1, - ) - return max(config["min_nb_lines"], max_nb_lines) - return config["max_nb_lines"] + return proba def generate_synthetic_page_sample(self): config = self.params["config"]["synthetic_data"] diff --git a/dan/manager/training.py b/dan/manager/training.py index 682e3091735c01244e9249568675eb1ad86c53d0..41ba05cb60b6da7c326c71f4e98f139fbeaf05ec 100644 --- a/dan/manager/training.py +++ b/dan/manager/training.py @@ -1160,6 +1160,9 @@ class Manager(OCRManager): "syn_max_lines": self.dataset.train_dataset.get_syn_max_lines() if self.params["dataset_params"]["config"]["synthetic_data"] else 0, + "syn_prob_lines": self.dataset.train_dataset.get_syn_proba_lines() + if self.params["dataset_params"]["config"]["synthetic_data"] + else 0, } return values diff --git a/dan/ocr/document/train.py b/dan/ocr/document/train.py index a4d3f5e88aaeba5fdb3872ecedb636be35a1077d..b173ba99e3d7d90988c01b7f8d185a8fa9343846 100644 --- a/dan/ocr/document/train.py +++ b/dan/ocr/document/train.py @@ -209,6 +209,7 @@ def run(): "cer", "wer", "syn_max_lines", + "syn_prob_lines", ], # Metrics name for training "eval_metrics": [ "cer",