Skip to content
Snippets Groups Projects
Commit a47466ef authored by Yoann Schneider's avatar Yoann Schneider :tennis: Committed by Manon Blanco
Browse files

Metrics manager refactoring

parent 0f154c32
No related branches found
No related tags found
1 merge request: !249 Metrics manager refactoring
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import re import re
from collections import defaultdict
from operator import attrgetter from operator import attrgetter
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Dict, List, Optional
import editdistance import editdistance
import numpy as np import numpy as np
from dan.utils import parse_tokens from dan.utils import parse_tokens
# Matches a single punctuation character (used to remove punctuation)
REGEX_PUNCTUATION = re.compile(r"([\[\]{}/\\()\"'&+*=<>?.;:,!\-—_€#%°])")
# Matches a run of one or more line breaks (used to remove consecutive linebreaks)
REGEX_CONSECUTIVE_LINEBREAKS = re.compile(r"\n+")
# Matches a run of one or more space characters (used to remove consecutive spaces)
REGEX_CONSECUTIVE_SPACES = re.compile(r" +")
# Matches a run of one or more whitespace characters (used to keep only one space character)
REGEX_ONLY_ONE_SPACE = re.compile(r"\s+")
class MetricManager: class MetricManager:
def __init__(self, metric_names, dataset_name, tokens: Optional[Path]): def __init__(
self.dataset_name = dataset_name self, metric_names: List[str], dataset_name: str, tokens: Optional[Path]
):
self.dataset_name: str = dataset_name
self.remove_tokens: str = None
self.layout_tokens = None
if tokens: if tokens:
tokens = parse_tokens(tokens) tokens = parse_tokens(tokens)
self.layout_tokens = "".join( layout_tokens = "".join(
list(map(attrgetter("start"), tokens.values())) list(map(attrgetter("start"), tokens.values()))
+ list(map(attrgetter("end"), tokens.values())) + list(map(attrgetter("end"), tokens.values()))
) )
self.metric_names = metric_names self.remove_tokens: re.Pattern = re.compile(r"([" + layout_tokens + "])")
self.epoch_metrics = None self.metric_names: List[str] = metric_names
self.epoch_metrics = defaultdict(list)
self.linked_metrics = { def edit_cer_from_string(self, gt: str, pred: str):
"cer": ["edit_chars", "nb_chars"], """
"wer": ["edit_words", "nb_words"], Format and compute edit distance between two strings at character level
"wer_no_punct": ["edit_words_no_punct", "nb_words_no_punct"], """
} gt = self.format_string_for_cer(gt)
pred = self.format_string_for_cer(pred)
return editdistance.eval(gt, pred)
def nb_chars_cer_from_string(self, gt: str) -> int:
"""
Compute length after formatting of ground truth string
"""
return len(self.format_string_for_cer(gt))
self.init_metrics() def format_string_for_wer(self, text: str, remove_punct: bool = False):
"""
Format string for WER computation: remove layout tokens, treat punctuation as word, replace line break by space
"""
if remove_punct:
text = REGEX_PUNCTUATION.sub("", text)
if self.remove_tokens is not None:
text = self.remove_tokens.sub("", text)
return REGEX_ONLY_ONE_SPACE.sub(" ", text).strip().split(" ")
def init_metrics(self): def format_string_for_cer(self, text: str):
""" """
Initialization of the metrics specified in metrics_name Format string for CER computation: remove layout tokens and extra spaces
""" """
self.epoch_metrics = { if self.remove_tokens is not None:
"nb_samples": list(), text = self.remove_tokens.sub("", text)
"names": list(),
}
for metric_name in self.metric_names: text = REGEX_CONSECUTIVE_LINEBREAKS.sub("\n", text)
if metric_name in self.linked_metrics: return REGEX_CONSECUTIVE_SPACES.sub(" ", text).strip()
for linked_metric_name in self.linked_metrics[metric_name]:
if linked_metric_name not in self.epoch_metrics:
self.epoch_metrics[linked_metric_name] = list()
else:
self.epoch_metrics[metric_name] = list()
def update_metrics(self, batch_metrics): def update_metrics(self, batch_metrics):
""" """
Add batch metrics to the metrics Add batch metrics to the metrics
""" """
for key in batch_metrics: for key in batch_metrics:
if key in self.epoch_metrics: self.epoch_metrics[key] += batch_metrics[key]
self.epoch_metrics[key] += batch_metrics[key]
def get_display_values(self, output=False): def get_display_values(self, output: bool = False):
""" """
format metrics values for shell display purposes Format metrics values for shell display purposes
""" """
metric_names = self.metric_names.copy() metric_names = self.metric_names.copy()
if output: if output:
metric_names.extend(["nb_samples"]) metric_names.append("nb_samples")
display_values = dict() display_values = dict()
for metric_name in metric_names: for metric_name in metric_names:
value = None match metric_name:
if output: case "time" | "nb_samples":
if metric_name == "nb_samples": if not output:
value = int(np.sum(self.epoch_metrics[metric_name])) continue
elif metric_name == "time":
value = int(np.sum(self.epoch_metrics[metric_name])) value = int(np.sum(self.epoch_metrics[metric_name]))
sample_time = value / np.sum(self.epoch_metrics["nb_samples"]) if metric_name == "time":
display_values["sample_time"] = float(round(sample_time, 4)) sample_time = value / np.sum(self.epoch_metrics["nb_samples"])
if metric_name == "cer": display_values["sample_time"] = float(round(sample_time, 4))
value = float( display_values[metric_name] = value
np.sum(self.epoch_metrics["edit_chars"]) continue
/ np.sum(self.epoch_metrics["nb_chars"]) case "cer":
) num_name, denom_name = "edit_chars", "nb_chars"
if output: case "wer" | "wer_no_punct":
display_values["nb_chars"] = int( suffix = metric_name[3:]
np.sum(self.epoch_metrics["nb_chars"]) num_name, denom_name = "edit_words" + suffix, "nb_words" + suffix
case "loss" | "loss_ce":
display_values[metric_name] = round(
float(
np.average(
self.epoch_metrics[metric_name],
weights=np.array(self.epoch_metrics["nb_samples"]),
),
),
4,
) )
elif metric_name == "wer": continue
value = float( case _:
np.sum(self.epoch_metrics["edit_words"]) continue
/ np.sum(self.epoch_metrics["nb_words"])
)
if output:
display_values["nb_words"] = int(
np.sum(self.epoch_metrics["nb_words"])
)
elif metric_name == "wer_no_punct":
value = float(
np.sum(self.epoch_metrics["edit_words_no_punct"])
/ np.sum(self.epoch_metrics["nb_words_no_punct"])
)
if output:
display_values["nb_words_no_punct"] = int(
np.sum(self.epoch_metrics["nb_words_no_punct"])
)
elif metric_name in [
"loss",
"loss_ce",
]:
value = float(
np.average(
self.epoch_metrics[metric_name],
weights=np.array(self.epoch_metrics["nb_samples"]),
)
)
elif value is None:
continue
value = float(
np.sum(self.epoch_metrics[num_name])
/ np.sum(self.epoch_metrics[denom_name])
)
if output:
display_values[denom_name] = int(np.sum(self.epoch_metrics[denom_name]))
display_values[metric_name] = round(value, 4) display_values[metric_name] = round(value, 4)
return display_values return display_values
def compute_metrics(self, values, metric_names): def compute_metrics(
self, values: Dict[str, int | float], metric_names: List[str]
) -> Dict[str, List[int | float]]:
metrics = { metrics = {
"nb_samples": [ "nb_samples": [
values["nb_samples"], values["nb_samples"],
...@@ -125,111 +135,30 @@ class MetricManager: ...@@ -125,111 +135,30 @@ class MetricManager:
} }
if "time" in values: if "time" in values:
metrics["time"] = [values["time"]] metrics["time"] = [values["time"]]
gt, prediction = values["str_y"], values["str_x"]
for metric_name in metric_names: for metric_name in metric_names:
if metric_name == "cer": match metric_name:
metrics["edit_chars"] = [ case "cer":
edit_cer_from_string(u, v, self.layout_tokens) metrics["edit_chars"] = list(
for u, v in zip(values["str_y"], values["str_x"]) map(self.edit_cer_from_string, gt, prediction)
] )
metrics["nb_chars"] = [ metrics["nb_chars"] = list(map(self.nb_chars_cer_from_string, gt))
nb_chars_cer_from_string(gt, self.layout_tokens) case "wer" | "wer_no_punct":
for gt in values["str_y"] suffix = metric_name[3:]
] split_gt = list(map(self.format_string_for_wer, gt, [bool(suffix)]))
elif metric_name == "wer": split_pred = list(
split_gt = [ map(self.format_string_for_wer, prediction, [bool(suffix)])
format_string_for_wer(gt, self.layout_tokens) )
for gt in values["str_y"] metrics["edit_words" + suffix] = list(
] map(editdistance.eval, split_gt, split_pred)
split_pred = [ )
format_string_for_wer(pred, self.layout_tokens) metrics["nb_words" + suffix] = list(map(len, split_gt))
for pred in values["str_x"] case "loss" | "loss_ce":
] metrics[metric_name] = [
metrics["edit_words"] = [ values[metric_name],
edit_wer_from_formatted_split_text(gt, pred) ]
for (gt, pred) in zip(split_gt, split_pred)
]
metrics["nb_words"] = [len(gt) for gt in split_gt]
elif metric_name == "wer_no_punct":
split_gt = [
format_string_for_wer(gt, self.layout_tokens, remove_punct=True)
for gt in values["str_y"]
]
split_pred = [
format_string_for_wer(pred, self.layout_tokens, remove_punct=True)
for pred in values["str_x"]
]
metrics["edit_words_no_punct"] = [
edit_wer_from_formatted_split_text(gt, pred)
for (gt, pred) in zip(split_gt, split_pred)
]
metrics["nb_words_no_punct"] = [len(gt) for gt in split_gt]
elif metric_name in [
"loss_ce",
"loss",
]:
metrics[metric_name] = [
values[metric_name],
]
return metrics return metrics
def get(self, name): def get(self, name: str):
return self.epoch_metrics[name] return self.epoch_metrics[name]
def keep_all_but_ner_tokens(str, tokens):
    """
    Remove all ner tokens from string

    :param str: Text to clean. The parameter name shadows the builtin but is
        kept unchanged so existing callers keep working.
    :param tokens: String whose individual characters are the tokens to remove.
    :return: The text without any of the characters listed in ``tokens``.
    """
    # An empty character class is not a valid pattern: nothing to remove.
    if not tokens:
        return str
    # Escape the tokens so that characters such as "]", "^", "-" or "\" are
    # matched literally instead of altering the character class.
    return re.sub("[" + re.escape(tokens) + "]", "", str)
def edit_cer_from_string(gt, pred, layout_tokens=None):
    """
    Compute the character-level edit distance between a ground truth and a
    prediction, after both went through format_string_for_cer
    """
    formatted_gt, formatted_pred = (
        format_string_for_cer(text, layout_tokens) for text in (gt, pred)
    )
    return editdistance.eval(formatted_gt, formatted_pred)
def nb_chars_cer_from_string(gt, layout_tokens=None):
    """
    Return the number of characters of the ground truth string once it has
    been formatted by format_string_for_cer
    """
    formatted_gt = format_string_for_cer(gt, layout_tokens)
    return len(formatted_gt)
def format_string_for_wer(str, layout_tokens, remove_punct=False):
    """
    Format string for WER computation: optionally remove punctuation, remove
    layout tokens, replace runs of spaces and line breaks by a single space,
    then split the text on spaces

    :param str: Text to format. The parameter name shadows the builtin but is
        kept unchanged so existing callers keep working.
    :param layout_tokens: String of layout token characters to remove, or
        None (or an empty string) to skip this step.
    :param remove_punct: Whether punctuation characters are removed first.
    :return: List of words. An empty text yields ``[""]``, not an empty list.
    """
    text = str
    if remove_punct:
        # Remove punctuation
        text = re.sub(r"([\[\]{}/\\()\"'&+*=<>?.;:,!\-—_€#%°])", "", text)
    # Truthiness check rather than `is not None`: an empty token string has
    # nothing to remove and would otherwise build an invalid regex pattern.
    if layout_tokens:
        # Remove layout tokens from metric
        text = keep_all_but_ner_tokens(text, layout_tokens)
    # Keep only one space character
    text = re.sub("([ \n])+", " ", text).strip()
    return text.split(" ")
def format_string_for_cer(str, layout_tokens):
    """
    Format string for CER computation: remove layout tokens and extra spaces

    :param str: Text to format. The parameter name shadows the builtin but is
        kept unchanged so existing callers keep working.
    :param layout_tokens: String of layout token characters to remove, or
        None (or an empty string) to skip this step.
    :return: The text without layout tokens, with consecutive line breaks and
        consecutive spaces collapsed, stripped of surrounding whitespace.
    """
    text = str
    # Truthiness check rather than `is not None`: an empty token string has
    # nothing to remove and would otherwise build an invalid regex pattern.
    if layout_tokens:
        # Remove layout tokens from metric
        text = keep_all_but_ner_tokens(text, layout_tokens)
    # Remove consecutive line breaks
    text = re.sub("([\n])+", "\n", text)
    # Remove consecutive spaces
    return re.sub("([ ])+", " ", text).strip()
def edit_wer_from_formatted_split_text(gt, pred):
    """
    Word-level edit distance between two texts already formatted and split
    into lists
    """
    word_distance = editdistance.eval(gt, pred)
    return word_distance
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment