Commit 6cdfc750, authored by Solene Tarride, committed by Yoann Schneider

Normalize wer computation

parent 80771a36
Merge request: !30 Normalize wer computation
@@ -31,6 +31,7 @@ class MetricManager:
         self.linked_metrics = {
             "cer": ["edit_chars", "nb_chars"],
             "wer": ["edit_words", "nb_words"],
+            "wer_no_punct": ["edit_words_no_punct", "nb_words_no_punct"],
             "loer": [
                 "edit_graph",
                 "nb_nodes_and_edges",
@@ -127,6 +128,14 @@ class MetricManager:
             )
             if output:
                 display_values["nb_words"] = np.sum(self.epoch_metrics["nb_words"])
+        elif metric_name == "wer_no_punct":
+            value = np.sum(self.epoch_metrics["edit_words_no_punct"]) / np.sum(
+                self.epoch_metrics["nb_words_no_punct"]
+            )
+            if output:
+                display_values["nb_words_no_punct"] = np.sum(
+                    self.epoch_metrics["nb_words_no_punct"]
+                )
         elif metric_name in [
             "loss",
             "loss_ctc",
@@ -183,6 +192,20 @@ class MetricManager:
                 for (gt, pred) in zip(split_gt, split_pred)
             ]
             metrics["nb_words"] = [len(gt) for gt in split_gt]
+        elif metric_name == "wer_no_punct":
+            split_gt = [
+                format_string_for_wer(gt, self.layout_tokens, remove_punct=True)
+                for gt in values["str_y"]
+            ]
+            split_pred = [
+                format_string_for_wer(pred, self.layout_tokens, remove_punct=True)
+                for pred in values["str_x"]
+            ]
+            metrics["edit_words_no_punct"] = [
+                edit_wer_from_formatted_split_text(gt, pred)
+                for (gt, pred) in zip(split_gt, split_pred)
+            ]
+            metrics["nb_words_no_punct"] = [len(gt) for gt in split_gt]
         elif metric_name in [
             "loss_ctc",
             "loss_ce",
@@ -258,22 +281,23 @@ def nb_chars_cer_from_string(gt, layout_tokens=None):
     return len(format_string_for_cer(gt, layout_tokens))
 
 
-def edit_wer_from_string(gt, pred, layout_tokens=None):
+def edit_wer_from_string(gt, pred, layout_tokens=None, remove_punct=False):
     """
     Format and compute edit distance between two strings at word level
     """
-    split_gt = format_string_for_wer(gt, layout_tokens)
-    split_pred = format_string_for_wer(pred, layout_tokens)
+    split_gt = format_string_for_wer(gt, layout_tokens, remove_punct)
+    split_pred = format_string_for_wer(pred, layout_tokens, remove_punct)
     return edit_wer_from_formatted_split_text(split_gt, split_pred)
 
 
-def format_string_for_wer(str, layout_tokens):
+def format_string_for_wer(str, layout_tokens, remove_punct=False):
     """
     Format string for WER computation: remove layout tokens, treat punctuation as word, replace line break by space
     """
     str = re.sub(
         r"([\[\]{}/\\()\"'&+*=<>?.;:,!\-—_€#%°])", r" \1 ", str
     )  # punctuation processed as word
+    if remove_punct:
+        str = re.sub(
+            r"([\[\]{}/\\()\"'&+*=<>?.;:,!\-—_€#%°])", "", str
+        )  # remove punctuation
     if layout_tokens is not None:
         str = keep_all_but_tokens(
             str, layout_tokens
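The remove_punct branch reuses the same character class as the padding step, so with remove_punct=False punctuation is scored as separate word tokens, while with remove_punct=True it disappears entirely before the split. A self-contained demonstration (the layout-token removal and the final whitespace split are assumed to match the repo's full function, which is truncated in this hunk):

import re

PUNCT = r"([\[\]{}/\\()\"'&+*=<>?.;:,!\-—_€#%°])"

def format_for_wer(text, remove_punct=False):
    text = re.sub(PUNCT, r" \1 ", text)  # pad punctuation so it splits as words
    if remove_punct:
        text = re.sub(PUNCT, "", text)   # then drop it entirely
    return text.split()                  # whitespace split into word tokens

print(format_for_wer("Hello, world!"))
# ['Hello', ',', 'world', '!']
print(format_for_wer("Hello, world!", remove_punct=True))
# ['Hello', 'world']

Note that removing after padding leaves extra spaces, which split() collapses, so the order of the two substitutions does not change the resulting token list.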
@@ -369,7 +369,7 @@ class OCRDataset(GenericDataset):
                 min(nb_samples, self.synthetic_data["num_steps_proba"]),
                 self.synthetic_data["num_steps_proba"],
             )
-        return proba
+        return proba
 
     def generate_synthetic_page_sample(self):
         max_nb_lines_per_page = self.get_syn_max_lines()
@@ -32,7 +32,7 @@ def train_and_test(rank, params):
     model.params["training_params"]["load_epoch"] = "best"
     model.load_model()
 
-    metrics = ["cer", "wer", "time"]
+    metrics = ["cer", "wer", "wer_no_punct", "time"]
     for dataset_name in params["dataset_params"]["datasets"].keys():
         for set_name in ["test", "val", "train"]:
             model.predict(
@@ -208,12 +208,14 @@ def run():
             "loss_ce",
             "cer",
             "wer",
+            "wer_no_punct",
             "syn_max_lines",
             "syn_prob_lines",
         ],  # Metrics name for training
         "eval_metrics": [
             "cer",
             "wer",
+            "wer_no_punct",
         ],  # Metrics name for evaluation on validation set during training
         "force_cpu": False,  # True for debug purposes
         "max_char_prediction": 1000,  # max number of token prediction
@@ -136,11 +136,17 @@ def run():
         "focus_metric": "cer",  # Metrics to focus on to determine best epoch
         "expected_metric_value": "low",  # ["high", "low"] What is best for the focus metric value
         "set_name_focus_metric": "{}-val".format(dataset_name),
-        "train_metrics": ["loss_ctc", "cer", "wer"],  # Metrics name for training
+        "train_metrics": [
+            "loss_ctc",
+            "cer",
+            "wer",
+            "wer_no_punct",
+        ],  # Metrics name for training
         "eval_metrics": [
             "loss_ctc",
             "cer",
             "wer",
+            "wer_no_punct",
         ],  # Metrics name for evaluation on validation set during training
         "force_cpu": False,  # True for debug purposes to run on cpu only
     },
@@ -37,6 +37,7 @@ def train_and_test(rank, params):
     metrics = [
         "cer",
         "wer",
+        "wer_no_punct",
         "time",
     ]
     for dataset_name in params["dataset_params"]["datasets"].keys():
@@ -178,11 +179,17 @@ def run():
         "set_name_focus_metric": "{}-val".format(
             dataset_name
         ),  # Which dataset to focus on to select best weights
-        "train_metrics": ["loss_ctc", "cer", "wer"],  # Metrics name for training
+        "train_metrics": [
+            "loss_ctc",
+            "cer",
+            "wer",
+            "wer_no_punct",
+        ],  # Metrics name for training
         "eval_metrics": [
             "loss_ctc",
             "cer",
             "wer",
+            "wer_no_punct",
         ],  # Metrics name for evaluation on validation set during training
         "force_cpu": False,  # True for debug purposes to run on cpu only
     },