Skip to content
Snippets Groups Projects
Commit 7a4af1eb authored by Solene Tarride's avatar Solene Tarride
Browse files

monitor probability of synthetic documents

parent 2ecb07cd
No related branches found
No related tags found
1 merge request: !24 "Train with synthetic documents"
......@@ -127,7 +127,13 @@ class MetricManager:
)
if output:
display_values["nb_words"] = np.sum(self.epoch_metrics["nb_words"])
elif metric_name in ["loss", "loss_ctc", "loss_ce", "syn_max_lines"]:
elif metric_name in [
"loss",
"loss_ctc",
"loss_ce",
"syn_max_lines",
"syn_prob_lines",
]:
value = np.average(
self.epoch_metrics[metric_name],
weights=np.array(self.epoch_metrics["nb_samples"]),
......@@ -182,6 +188,7 @@ class MetricManager:
"loss_ce",
"loss",
"syn_max_lines",
"syn_prob_lines",
]:
metrics[metric_name] = [
values[metric_name],
......
......@@ -284,9 +284,37 @@ class OCRDataset(GenericDataset):
def generate_synthetic_data(self, sample):
    """Randomly replace a real training sample with a synthetic one.

    A sample is replaced with probability ``proba`` (the scheduled value
    from ``get_syn_proba_lines``); when both ``init_proba`` and
    ``end_proba`` are 1, synthesis is forced and the random draw is
    skipped entirely.

    :param sample: dict with at least "img" and "label" keys.
    :return: the original sample, the sample with its image replaced by a
        typed rendering of its label (mode "line_hw_to_printed"), or a
        fully synthetic page sample.
    """
    config = self.params["config"]["synthetic_data"]
    proba = self.get_syn_proba_lines()
    # Only draw when synthesis is not forced (proba pinned to 1 on both ends).
    if not (config["init_proba"] == config["end_proba"] == 1):
        if rand() > proba:
            return sample
    # Idiomatic single lookup instead of `"mode" in config and config["mode"] == ...`.
    if config.get("mode") == "line_hw_to_printed":
        # Keep the label, swap the handwritten image for a printed render.
        sample["img"] = self.generate_typed_text_line_image(sample["label"])
        return sample
    return self.generate_synthetic_page_sample()
def get_syn_max_lines(self):
    """Return the current cap on the number of synthetic lines per document.

    Without curriculum learning the cap is a constant ``max_nb_lines``.
    With curriculum learning it grows by one line every ``curr_step``
    training samples seen after ``curr_start``, clamped to the
    ``[min_nb_lines, max_nb_lines]`` range.
    """
    syn_cfg = self.params["config"]["synthetic_data"]
    if not syn_cfg["curriculum"]:
        return syn_cfg["max_nb_lines"]
    # Number of training samples processed so far.
    seen_samples = self.training_info["step"] * self.params["batch_size"]
    curriculum_cap = (seen_samples - syn_cfg["curr_start"]) // syn_cfg["curr_step"] + 1
    bounded = min(syn_cfg["max_nb_lines"], curriculum_cap)
    return max(syn_cfg["min_nb_lines"], bounded)
def get_syn_proba_lines(self):
config = self.params["config"]["synthetic_data"]
if config["init_proba"] == config["end_proba"]:
return config["init_proba"]
else:
nb_samples = self.training_info["step"] * self.params["batch_size"]
if config["start_scheduler_at_max_line"]:
max_step = config["num_steps_proba"]
current_step = max(
......@@ -315,25 +343,7 @@ class OCRDataset(GenericDataset):
min(nb_samples, config["num_steps_proba"]),
config["num_steps_proba"],
)
if rand() > proba:
return sample
if "mode" in config and config["mode"] == "line_hw_to_printed":
sample["img"] = self.generate_typed_text_line_image(sample["label"])
return sample
return self.generate_synthetic_page_sample()
def get_syn_max_lines(self):
    """Return the maximum number of synthetic lines allowed at the current training step."""
    config = self.params["config"]["synthetic_data"]
    if config["curriculum"]:
        # Curriculum schedule: the allowance grows with the number of
        # training samples processed so far.
        nb_samples = self.training_info["step"] * self.params["batch_size"]
        max_nb_lines = min(
            config["max_nb_lines"],
            # One extra line every curr_step samples after curr_start.
            (nb_samples - config["curr_start"]) // config["curr_step"] + 1,
        )
        # Never drop below the configured floor.
        return max(config["min_nb_lines"], max_nb_lines)
    return config["max_nb_lines"]
return proba
def generate_synthetic_page_sample(self):
config = self.params["config"]["synthetic_data"]
......
......@@ -1160,6 +1160,9 @@ class Manager(OCRManager):
"syn_max_lines": self.dataset.train_dataset.get_syn_max_lines()
if self.params["dataset_params"]["config"]["synthetic_data"]
else 0,
"syn_prob_lines": self.dataset.train_dataset.get_syn_proba_lines()
if self.params["dataset_params"]["config"]["synthetic_data"]
else 0,
}
return values
......
......@@ -209,6 +209,7 @@ def run():
"cer",
"wer",
"syn_max_lines",
"syn_prob_lines",
], # Metrics name for training
"eval_metrics": [
"cer",
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.