Skip to content
Snippets Groups Projects

Normalize wer computation

Merged Solene Tarride requested to merge normalize-wer-computation into main
4 files
+ 41
9
Compare changes
  • Side-by-side
  • Inline
Files
4
+ 31
7
@@ -31,6 +31,7 @@ class MetricManager:
self.linked_metrics = {
"cer": ["edit_chars", "nb_chars"],
"wer": ["edit_words", "nb_words"],
"wer_no_punct": ["edit_words_no_punct", "nb_words_no_punct"],
"loer": [
"edit_graph",
"nb_nodes_and_edges",
@@ -127,6 +128,14 @@ class MetricManager:
)
if output:
display_values["nb_words"] = np.sum(self.epoch_metrics["nb_words"])
elif metric_name == "wer_no_punct":
value = np.sum(self.epoch_metrics["edit_words_no_punct"]) / np.sum(
self.epoch_metrics["nb_words_no_punct"]
)
if output:
display_values["nb_words_no_punct"] = np.sum(
self.epoch_metrics["nb_words_no_punct"]
)
elif metric_name in [
"loss",
"loss_ctc",
@@ -183,6 +192,20 @@ class MetricManager:
for (gt, pred) in zip(split_gt, split_pred)
]
metrics["nb_words"] = [len(gt) for gt in split_gt]
elif metric_name == "wer_no_punct":
split_gt = [
format_string_for_wer(gt, self.layout_tokens, remove_punct=True)
for gt in values["str_y"]
]
split_pred = [
format_string_for_wer(pred, self.layout_tokens, remove_punct=True)
for pred in values["str_x"]
]
metrics["edit_words_no_punct"] = [
edit_wer_from_formatted_split_text(gt, pred)
for (gt, pred) in zip(split_gt, split_pred)
]
metrics["nb_words_no_punct"] = [len(gt) for gt in split_gt]
elif metric_name in [
"loss_ctc",
"loss_ce",
@@ -258,22 +281,23 @@ def nb_chars_cer_from_string(gt, layout_tokens=None):
return len(format_string_for_cer(gt, layout_tokens))
def edit_wer_from_string(gt, pred, layout_tokens=None):
def edit_wer_from_string(gt, pred, layout_tokens=None, remove_punct=False):
"""
Format and compute edit distance between two strings at word level
"""
split_gt = format_string_for_wer(gt, layout_tokens)
split_pred = format_string_for_wer(pred, layout_tokens)
split_gt = format_string_for_wer(gt, layout_tokens, remove_punct)
split_pred = format_string_for_wer(pred, layout_tokens, remove_punct)
return edit_wer_from_formatted_split_text(split_gt, split_pred)
def format_string_for_wer(str, layout_tokens):
def format_string_for_wer(str, layout_tokens, remove_punct=False):
"""
Format string for WER computation: remove layout tokens, treat punctuation as word, replace line break by space
"""
str = re.sub(
r"([\[\]{}/\\()\"'&+*=<>?.;:,!\-—_€#%°])", r" \1 ", str
) # punctuation processed as word
if remove_punct:
str = re.sub(
r"([\[\]{}/\\()\"'&+*=<>?.;:,!\-—_€#%°])", "", str
) # remove punctuation
if layout_tokens is not None:
str = keep_all_but_tokens(
str, layout_tokens
Loading