Skip to content
Snippets Groups Projects
Commit a47466ef authored by Yoann Schneider's avatar Yoann Schneider :tennis: Committed by Manon Blanco
Browse files

Metrics manager refactoring

parent 0f154c32
No related branches found
No related tags found
1 merge request!249Metrics manager refactoring
# -*- coding: utf-8 -*-
import re
from collections import defaultdict
from operator import attrgetter
from pathlib import Path
from typing import Optional
from typing import Dict, List, Optional
import editdistance
import numpy as np
from dan.utils import parse_tokens
# Remove punctuation
REGEX_PUNCTUATION = re.compile(r"([\[\]{}/\\()\"'&+*=<>?.;:,!\-—_€#%°])")
# Remove consecutive linebreaks
REGEX_CONSECUTIVE_LINEBREAKS = re.compile(r"\n+")
# Remove consecutive spaces
REGEX_CONSECUTIVE_SPACES = re.compile(r" +")
# Keep only one space character
REGEX_ONLY_ONE_SPACE = re.compile(r"\s+")
class MetricManager:
def __init__(self, metric_names, dataset_name, tokens: Optional[Path]):
self.dataset_name = dataset_name
def __init__(
self, metric_names: List[str], dataset_name: str, tokens: Optional[Path]
):
self.dataset_name: str = dataset_name
self.remove_tokens: str = None
self.layout_tokens = None
if tokens:
tokens = parse_tokens(tokens)
self.layout_tokens = "".join(
layout_tokens = "".join(
list(map(attrgetter("start"), tokens.values()))
+ list(map(attrgetter("end"), tokens.values()))
)
self.metric_names = metric_names
self.epoch_metrics = None
self.remove_tokens: re.Pattern = re.compile(r"([" + layout_tokens + "])")
self.metric_names: List[str] = metric_names
self.epoch_metrics = defaultdict(list)
self.linked_metrics = {
"cer": ["edit_chars", "nb_chars"],
"wer": ["edit_words", "nb_words"],
"wer_no_punct": ["edit_words_no_punct", "nb_words_no_punct"],
}
def edit_cer_from_string(self, gt: str, pred: str):
"""
Format and compute edit distance between two strings at character level
"""
gt = self.format_string_for_cer(gt)
pred = self.format_string_for_cer(pred)
return editdistance.eval(gt, pred)
def nb_chars_cer_from_string(self, gt: str) -> int:
"""
Compute length after formatting of ground truth string
"""
return len(self.format_string_for_cer(gt))
self.init_metrics()
def format_string_for_wer(self, text: str, remove_punct: bool = False):
"""
Format string for WER computation: remove layout tokens, treat punctuation as word, replace line break by space
"""
if remove_punct:
text = REGEX_PUNCTUATION.sub("", text)
if self.remove_tokens is not None:
text = self.remove_tokens.sub("", text)
return REGEX_ONLY_ONE_SPACE.sub(" ", text).strip().split(" ")
def init_metrics(self):
def format_string_for_cer(self, text: str):
"""
Initialization of the metrics specified in metrics_name
Format string for CER computation: remove layout tokens and extra spaces
"""
self.epoch_metrics = {
"nb_samples": list(),
"names": list(),
}
if self.remove_tokens is not None:
text = self.remove_tokens.sub("", text)
for metric_name in self.metric_names:
if metric_name in self.linked_metrics:
for linked_metric_name in self.linked_metrics[metric_name]:
if linked_metric_name not in self.epoch_metrics:
self.epoch_metrics[linked_metric_name] = list()
else:
self.epoch_metrics[metric_name] = list()
text = REGEX_CONSECUTIVE_LINEBREAKS.sub("\n", text)
return REGEX_CONSECUTIVE_SPACES.sub(" ", text).strip()
def update_metrics(self, batch_metrics):
"""
Add batch metrics to the metrics
"""
for key in batch_metrics:
if key in self.epoch_metrics:
self.epoch_metrics[key] += batch_metrics[key]
self.epoch_metrics[key] += batch_metrics[key]
def get_display_values(self, output=False):
def get_display_values(self, output: bool = False):
"""
format metrics values for shell display purposes
Format metrics values for shell display purposes
"""
metric_names = self.metric_names.copy()
if output:
metric_names.extend(["nb_samples"])
metric_names.append("nb_samples")
display_values = dict()
for metric_name in metric_names:
value = None
if output:
if metric_name == "nb_samples":
value = int(np.sum(self.epoch_metrics[metric_name]))
elif metric_name == "time":
match metric_name:
case "time" | "nb_samples":
if not output:
continue
value = int(np.sum(self.epoch_metrics[metric_name]))
sample_time = value / np.sum(self.epoch_metrics["nb_samples"])
display_values["sample_time"] = float(round(sample_time, 4))
if metric_name == "cer":
value = float(
np.sum(self.epoch_metrics["edit_chars"])
/ np.sum(self.epoch_metrics["nb_chars"])
)
if output:
display_values["nb_chars"] = int(
np.sum(self.epoch_metrics["nb_chars"])
if metric_name == "time":
sample_time = value / np.sum(self.epoch_metrics["nb_samples"])
display_values["sample_time"] = float(round(sample_time, 4))
display_values[metric_name] = value
continue
case "cer":
num_name, denom_name = "edit_chars", "nb_chars"
case "wer" | "wer_no_punct":
suffix = metric_name[3:]
num_name, denom_name = "edit_words" + suffix, "nb_words" + suffix
case "loss" | "loss_ce":
display_values[metric_name] = round(
float(
np.average(
self.epoch_metrics[metric_name],
weights=np.array(self.epoch_metrics["nb_samples"]),
),
),
4,
)
elif metric_name == "wer":
value = float(
np.sum(self.epoch_metrics["edit_words"])
/ np.sum(self.epoch_metrics["nb_words"])
)
if output:
display_values["nb_words"] = int(
np.sum(self.epoch_metrics["nb_words"])
)
elif metric_name == "wer_no_punct":
value = float(
np.sum(self.epoch_metrics["edit_words_no_punct"])
/ np.sum(self.epoch_metrics["nb_words_no_punct"])
)
if output:
display_values["nb_words_no_punct"] = int(
np.sum(self.epoch_metrics["nb_words_no_punct"])
)
elif metric_name in [
"loss",
"loss_ce",
]:
value = float(
np.average(
self.epoch_metrics[metric_name],
weights=np.array(self.epoch_metrics["nb_samples"]),
)
)
elif value is None:
continue
continue
case _:
continue
value = float(
np.sum(self.epoch_metrics[num_name])
/ np.sum(self.epoch_metrics[denom_name])
)
if output:
display_values[denom_name] = int(np.sum(self.epoch_metrics[denom_name]))
display_values[metric_name] = round(value, 4)
return display_values
def compute_metrics(self, values, metric_names):
def compute_metrics(
self, values: Dict[str, int | float], metric_names: List[str]
) -> Dict[str, List[int | float]]:
metrics = {
"nb_samples": [
values["nb_samples"],
......@@ -125,111 +135,30 @@ class MetricManager:
}
if "time" in values:
metrics["time"] = [values["time"]]
gt, prediction = values["str_y"], values["str_x"]
for metric_name in metric_names:
if metric_name == "cer":
metrics["edit_chars"] = [
edit_cer_from_string(u, v, self.layout_tokens)
for u, v in zip(values["str_y"], values["str_x"])
]
metrics["nb_chars"] = [
nb_chars_cer_from_string(gt, self.layout_tokens)
for gt in values["str_y"]
]
elif metric_name == "wer":
split_gt = [
format_string_for_wer(gt, self.layout_tokens)
for gt in values["str_y"]
]
split_pred = [
format_string_for_wer(pred, self.layout_tokens)
for pred in values["str_x"]
]
metrics["edit_words"] = [
edit_wer_from_formatted_split_text(gt, pred)
for (gt, pred) in zip(split_gt, split_pred)
]
metrics["nb_words"] = [len(gt) for gt in split_gt]
elif metric_name == "wer_no_punct":
split_gt = [
format_string_for_wer(gt, self.layout_tokens, remove_punct=True)
for gt in values["str_y"]
]
split_pred = [
format_string_for_wer(pred, self.layout_tokens, remove_punct=True)
for pred in values["str_x"]
]
metrics["edit_words_no_punct"] = [
edit_wer_from_formatted_split_text(gt, pred)
for (gt, pred) in zip(split_gt, split_pred)
]
metrics["nb_words_no_punct"] = [len(gt) for gt in split_gt]
elif metric_name in [
"loss_ce",
"loss",
]:
metrics[metric_name] = [
values[metric_name],
]
match metric_name:
case "cer":
metrics["edit_chars"] = list(
map(self.edit_cer_from_string, gt, prediction)
)
metrics["nb_chars"] = list(map(self.nb_chars_cer_from_string, gt))
case "wer" | "wer_no_punct":
suffix = metric_name[3:]
split_gt = list(map(self.format_string_for_wer, gt, [bool(suffix)]))
split_pred = list(
map(self.format_string_for_wer, prediction, [bool(suffix)])
)
metrics["edit_words" + suffix] = list(
map(editdistance.eval, split_gt, split_pred)
)
metrics["nb_words" + suffix] = list(map(len, split_gt))
case "loss" | "loss_ce":
metrics[metric_name] = [
values[metric_name],
]
return metrics
def get(self, name):
def get(self, name: str):
return self.epoch_metrics[name]
def keep_all_but_ner_tokens(str, tokens):
"""
Remove all ner tokens from string
"""
return re.sub("([" + tokens + "])", "", str)
def edit_cer_from_string(gt, pred, layout_tokens=None):
"""
Format and compute edit distance between two strings at character level
"""
gt = format_string_for_cer(gt, layout_tokens)
pred = format_string_for_cer(pred, layout_tokens)
return editdistance.eval(gt, pred)
def nb_chars_cer_from_string(gt, layout_tokens=None):
"""
Compute length after formatting of ground truth string
"""
return len(format_string_for_cer(gt, layout_tokens))
def format_string_for_wer(str, layout_tokens, remove_punct=False):
"""
Format string for WER computation: remove layout tokens, treat punctuation as word, replace line break by space
"""
if remove_punct:
str = re.sub(
r"([\[\]{}/\\()\"'&+*=<>?.;:,!\-—_€#%°])", "", str
) # remove punctuation
if layout_tokens is not None:
str = keep_all_but_ner_tokens(
str, layout_tokens
) # remove layout tokens from metric
str = re.sub("([ \n])+", " ", str).strip() # keep only one space character
return str.split(" ")
def format_string_for_cer(str, layout_tokens):
"""
Format string for CER computation: remove layout tokens and extra spaces
"""
if layout_tokens is not None:
str = keep_all_but_ner_tokens(
str, layout_tokens
) # remove layout tokens from metric
str = re.sub("([\n])+", "\n", str) # remove consecutive line breaks
str = re.sub("([ ])+", " ", str).strip() # remove consecutive spaces
return str
def edit_wer_from_formatted_split_text(gt, pred):
"""
Compute edit distance at word level from formatted string as list
"""
return editdistance.eval(gt, pred)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment