diff --git a/dan/manager/metrics.py b/dan/manager/metrics.py
index 36d3a29f489dd30533e46cc4b732b3cb36c81d33..fdcc0ea53dd63e89b44ddaa17c295c68ea9938e5 100644
--- a/dan/manager/metrics.py
+++ b/dan/manager/metrics.py
@@ -31,6 +31,7 @@ class MetricManager:
         self.linked_metrics = {
             "cer": ["edit_chars", "nb_chars"],
             "wer": ["edit_words", "nb_words"],
+            "wer_no_punct": ["edit_words_no_punct", "nb_words_no_punct"],
             "loer": [
                 "edit_graph",
                 "nb_nodes_and_edges",
@@ -127,6 +128,14 @@ class MetricManager:
                 )
                 if output:
                     display_values["nb_words"] = np.sum(self.epoch_metrics["nb_words"])
+            elif metric_name == "wer_no_punct":
+                value = np.sum(self.epoch_metrics["edit_words_no_punct"]) / np.sum(
+                    self.epoch_metrics["nb_words_no_punct"]
+                )
+                if output:
+                    display_values["nb_words_no_punct"] = np.sum(
+                        self.epoch_metrics["nb_words_no_punct"]
+                    )
             elif metric_name in [
                 "loss",
                 "loss_ctc",
@@ -183,6 +192,20 @@ class MetricManager:
                     for (gt, pred) in zip(split_gt, split_pred)
                 ]
                 metrics["nb_words"] = [len(gt) for gt in split_gt]
+            elif metric_name == "wer_no_punct":
+                split_gt = [
+                    format_string_for_wer(gt, self.layout_tokens, remove_punct=True)
+                    for gt in values["str_y"]
+                ]
+                split_pred = [
+                    format_string_for_wer(pred, self.layout_tokens, remove_punct=True)
+                    for pred in values["str_x"]
+                ]
+                metrics["edit_words_no_punct"] = [
+                    edit_wer_from_formatted_split_text(gt, pred)
+                    for (gt, pred) in zip(split_gt, split_pred)
+                ]
+                metrics["nb_words_no_punct"] = [len(gt) for gt in split_gt]
             elif metric_name in [
                 "loss_ctc",
                 "loss_ce",
@@ -258,22 +281,23 @@ def nb_chars_cer_from_string(gt, layout_tokens=None):
     return len(format_string_for_cer(gt, layout_tokens))
 
 
-def edit_wer_from_string(gt, pred, layout_tokens=None):
+def edit_wer_from_string(gt, pred, layout_tokens=None, remove_punct=False):
     """
     Format and compute edit distance between two strings at word level
     """
-    split_gt = format_string_for_wer(gt, layout_tokens)
-    split_pred = format_string_for_wer(pred, layout_tokens)
+    split_gt = format_string_for_wer(gt, layout_tokens, remove_punct)
+    split_pred = format_string_for_wer(pred, layout_tokens, remove_punct)
     return edit_wer_from_formatted_split_text(split_gt, split_pred)
 
 
-def format_string_for_wer(str, layout_tokens):
+def format_string_for_wer(str, layout_tokens, remove_punct=False):
     """
     Format string for WER computation: remove layout tokens, treat punctuation as word, replace line break by space
     """
-    str = re.sub(
-        r"([\[\]{}/\\()\"'&+*=<>?.;:,!\-—_€#%°])", r" \1 ", str
-    )  # punctuation processed as word
+    if remove_punct:
+        str = re.sub(
+            r"([\[\]{}/\\()\"'&+*=<>?.;:,!\-—_€#%°])", "", str
+        )  # remove punctuation
     if layout_tokens is not None:
         str = keep_all_but_tokens(
             str, layout_tokens
diff --git a/dan/manager/ocr.py b/dan/manager/ocr.py
index 775fce6341a7e808bc2f4fef93f0897a569125b6..93b966de72d421bc6050fda2afa4eedf9ccbc179 100644
--- a/dan/manager/ocr.py
+++ b/dan/manager/ocr.py
@@ -369,7 +369,7 @@ class OCRDataset(GenericDataset):
             min(nb_samples, self.synthetic_data["num_steps_proba"]),
             self.synthetic_data["num_steps_proba"],
         )
-        return proba 
+        return proba
 
     def generate_synthetic_page_sample(self):
         max_nb_lines_per_page = self.get_syn_max_lines()
diff --git a/dan/ocr/document/train.py b/dan/ocr/document/train.py
index b173ba99e3d7d90988c01b7f8d185a8fa9343846..86af4d6332d986bca149a73f369bd15b0f3d6488 100644
--- a/dan/ocr/document/train.py
+++ b/dan/ocr/document/train.py
@@ -32,7 +32,7 @@ def train_and_test(rank, params):
         model.params["training_params"]["load_epoch"] = "best"
         model.load_model()
 
-    metrics = ["cer", "wer", "time"]
+    metrics = ["cer", "wer", "wer_no_punct", "time"]
     for dataset_name in params["dataset_params"]["datasets"].keys():
         for set_name in ["test", "val", "train"]:
             model.predict(
@@ -208,12 +208,14 @@ def run():
                 "loss_ce",
                 "cer",
                 "wer",
+                "wer_no_punct",
                 "syn_max_lines",
                 "syn_prob_lines",
             ],  # Metrics name for training
             "eval_metrics": [
                 "cer",
                 "wer",
+                "wer_no_punct",
             ],  # Metrics name for evaluation on validation set during training
             "force_cpu": False,  # True for debug purposes
             "max_char_prediction": 1000,  # max number of token prediction
diff --git a/dan/ocr/line/generate_synthetic.py b/dan/ocr/line/generate_synthetic.py
index 4efcd1129eb5f7e8d5e8185a27a05eeb49039368..879d809f7374875778a315fc6365a24172d8278f 100644
--- a/dan/ocr/line/generate_synthetic.py
+++ b/dan/ocr/line/generate_synthetic.py
@@ -136,11 +136,17 @@ def run():
             "focus_metric": "cer",  # Metrics to focus on to determine best epoch
             "expected_metric_value": "low",  # ["high", "low"] What is best for the focus metric value
             "set_name_focus_metric": "{}-val".format(dataset_name),
-            "train_metrics": ["loss_ctc", "cer", "wer"],  # Metrics name for training
+            "train_metrics": [
+                "loss_ctc",
+                "cer",
+                "wer",
+                "wer_no_punct",
+            ],  # Metrics name for training
             "eval_metrics": [
                 "loss_ctc",
                 "cer",
                 "wer",
+                "wer_no_punct",
             ],  # Metrics name for evaluation on validation set during training
             "force_cpu": False,  # True for debug purposes to run on cpu only
         },
diff --git a/dan/ocr/line/train.py b/dan/ocr/line/train.py
index e76a54f3b588b83ae24fbb85cdc8446bb64a28ff..86d6009e82110b562e669bca57fcbd7ef696facf 100644
--- a/dan/ocr/line/train.py
+++ b/dan/ocr/line/train.py
@@ -37,6 +37,7 @@ def train_and_test(rank, params):
     metrics = [
         "cer",
         "wer",
+        "wer_no_punct",
         "time",
     ]
     for dataset_name in params["dataset_params"]["datasets"].keys():
@@ -178,11 +179,17 @@ def run():
             "set_name_focus_metric": "{}-val".format(
                 dataset_name
             ),  # Which dataset to focus on to select best weights
-            "train_metrics": ["loss_ctc", "cer", "wer"],  # Metrics name for training
+            "train_metrics": [
+                "loss_ctc",
+                "cer",
+                "wer",
+                "wer_no_punct",
+            ],  # Metrics name for training
             "eval_metrics": [
                 "loss_ctc",
                 "cer",
                 "wer",
+                "wer_no_punct",
             ],  # Metrics name for evaluation on validation set during training
             "force_cpu": False,  # True for debug purposes to run on cpu only
         },