diff --git a/dan/ocr/manager/metrics.py b/dan/ocr/manager/metrics.py
index d941c21436ab557026bc2b2510f1450ee0139dbd..602e75ac8b354bf1be321c200af6e6c597136d79 100644
--- a/dan/ocr/manager/metrics.py
+++ b/dan/ocr/manager/metrics.py
@@ -1,123 +1,133 @@
 # -*- coding: utf-8 -*-
 import re
+from collections import defaultdict
 from operator import attrgetter
 from pathlib import Path
-from typing import Optional
+from typing import Dict, List, Optional
 
 import editdistance
 import numpy as np
 
 from dan.utils import parse_tokens
 
+# Remove punctuation
+REGEX_PUNCTUATION = re.compile(r"([\[\]{}/\\()\"'&+*=<>?.;:,!\-—_€#%°])")
+# Remove consecutive linebreaks
+REGEX_CONSECUTIVE_LINEBREAKS = re.compile(r"\n+")
+# Remove consecutive spaces
+REGEX_CONSECUTIVE_SPACES = re.compile(r" +")
+# Keep only one space character
+REGEX_ONLY_ONE_SPACE = re.compile(r"\s+")
+
 
 class MetricManager:
-    def __init__(self, metric_names, dataset_name, tokens: Optional[Path]):
-        self.dataset_name = dataset_name
+    def __init__(
+        self, metric_names: List[str], dataset_name: str, tokens: Optional[Path]
+    ):
+        self.dataset_name: str = dataset_name
+        self.remove_tokens: str = None
-        self.layout_tokens = None
         if tokens:
             tokens = parse_tokens(tokens)
-            self.layout_tokens = "".join(
+            layout_tokens = "".join(
                 list(map(attrgetter("start"), tokens.values()))
                 + list(map(attrgetter("end"), tokens.values()))
             )
-        self.metric_names = metric_names
-        self.epoch_metrics = None
+            self.remove_tokens: re.Pattern = re.compile(r"([" + layout_tokens + "])")
+        self.metric_names: List[str] = metric_names
+        self.epoch_metrics = defaultdict(list)
 
-        self.linked_metrics = {
-            "cer": ["edit_chars", "nb_chars"],
-            "wer": ["edit_words", "nb_words"],
-            "wer_no_punct": ["edit_words_no_punct", "nb_words_no_punct"],
-        }
+    def edit_cer_from_string(self, gt: str, pred: str):
+        """
+        Format and compute edit distance between two strings at character level
+        """
+        gt = self.format_string_for_cer(gt)
+        pred = self.format_string_for_cer(pred)
+        return editdistance.eval(gt, pred)
+
+    def nb_chars_cer_from_string(self, gt: str) -> int:
+        """
+        Compute length after formatting of ground truth string
+        """
+        return len(self.format_string_for_cer(gt))
 
-        self.init_metrics()
+    def format_string_for_wer(self, text: str, remove_punct: bool = False):
+        """
+        Format string for WER computation: remove layout tokens, treat punctuation as word, replace line break by space
+        """
+        if remove_punct:
+            text = REGEX_PUNCTUATION.sub("", text)
+        if self.remove_tokens is not None:
+            text = self.remove_tokens.sub("", text)
+        return REGEX_ONLY_ONE_SPACE.sub(" ", text).strip().split(" ")
 
-    def init_metrics(self):
+    def format_string_for_cer(self, text: str):
         """
-        Initialization of the metrics specified in metrics_name
+        Format string for CER computation: remove layout tokens and extra spaces
         """
-        self.epoch_metrics = {
-            "nb_samples": list(),
-            "names": list(),
-        }
+        if self.remove_tokens is not None:
+            text = self.remove_tokens.sub("", text)
 
-        for metric_name in self.metric_names:
-            if metric_name in self.linked_metrics:
-                for linked_metric_name in self.linked_metrics[metric_name]:
-                    if linked_metric_name not in self.epoch_metrics:
-                        self.epoch_metrics[linked_metric_name] = list()
-            else:
-                self.epoch_metrics[metric_name] = list()
+        text = REGEX_CONSECUTIVE_LINEBREAKS.sub("\n", text)
+        return REGEX_CONSECUTIVE_SPACES.sub(" ", text).strip()
 
     def update_metrics(self, batch_metrics):
        """
         Add batch metrics to the metrics
         """
         for key in batch_metrics:
-            if key in self.epoch_metrics:
-                self.epoch_metrics[key] += batch_metrics[key]
+            self.epoch_metrics[key] += batch_metrics[key]
 
-    def get_display_values(self, output=False):
+    def get_display_values(self, output: bool = False):
         """
-        format metrics values for shell display purposes
+        Format metrics values for shell display purposes
         """
         metric_names = self.metric_names.copy()
         if output:
-            metric_names.extend(["nb_samples"])
+            metric_names.append("nb_samples")
         display_values = dict()
         for metric_name in metric_names:
-            value = None
-            if output:
-                if metric_name == "nb_samples":
-                    value = int(np.sum(self.epoch_metrics[metric_name]))
-                elif metric_name == "time":
+            match metric_name:
+                case "time" | "nb_samples":
+                    if not output:
+                        continue
                     value = int(np.sum(self.epoch_metrics[metric_name]))
-                    sample_time = value / np.sum(self.epoch_metrics["nb_samples"])
-                    display_values["sample_time"] = float(round(sample_time, 4))
-            if metric_name == "cer":
-                value = float(
-                    np.sum(self.epoch_metrics["edit_chars"])
-                    / np.sum(self.epoch_metrics["nb_chars"])
-                )
-                if output:
-                    display_values["nb_chars"] = int(
-                        np.sum(self.epoch_metrics["nb_chars"])
+                    if metric_name == "time":
+                        sample_time = value / np.sum(self.epoch_metrics["nb_samples"])
+                        display_values["sample_time"] = float(round(sample_time, 4))
+                    display_values[metric_name] = value
+                    continue
+                case "cer":
+                    num_name, denom_name = "edit_chars", "nb_chars"
+                case "wer" | "wer_no_punct":
+                    suffix = metric_name[3:]
+                    num_name, denom_name = "edit_words" + suffix, "nb_words" + suffix
+                case "loss" | "loss_ce":
+                    display_values[metric_name] = round(
+                        float(
+                            np.average(
+                                self.epoch_metrics[metric_name],
+                                weights=np.array(self.epoch_metrics["nb_samples"]),
+                            ),
+                        ),
+                        4,
                     )
-            elif metric_name == "wer":
-                value = float(
-                    np.sum(self.epoch_metrics["edit_words"])
-                    / np.sum(self.epoch_metrics["nb_words"])
-                )
-                if output:
-                    display_values["nb_words"] = int(
-                        np.sum(self.epoch_metrics["nb_words"])
-                    )
-            elif metric_name == "wer_no_punct":
-                value = float(
-                    np.sum(self.epoch_metrics["edit_words_no_punct"])
-                    / np.sum(self.epoch_metrics["nb_words_no_punct"])
-                )
-                if output:
-                    display_values["nb_words_no_punct"] = int(
-                        np.sum(self.epoch_metrics["nb_words_no_punct"])
-                    )
-            elif metric_name in [
-                "loss",
-                "loss_ce",
-            ]:
-                value = float(
-                    np.average(
-                        self.epoch_metrics[metric_name],
-                        weights=np.array(self.epoch_metrics["nb_samples"]),
-                    )
-                )
-            elif value is None:
-                continue
+                    continue
+                case _:
+                    continue
+            value = float(
+                np.sum(self.epoch_metrics[num_name])
+                / np.sum(self.epoch_metrics[denom_name])
+            )
+            if output:
+                display_values[denom_name] = int(np.sum(self.epoch_metrics[denom_name]))
             display_values[metric_name] = round(value, 4)
         return display_values
 
-    def compute_metrics(self, values, metric_names):
+    def compute_metrics(
+        self, values: Dict[str, int | float], metric_names: List[str]
+    ) -> Dict[str, List[int | float]]:
         metrics = {
             "nb_samples": [
                 values["nb_samples"],
@@ -125,111 +135,30 @@ class MetricManager:
         }
         if "time" in values:
             metrics["time"] = [values["time"]]
+
+        gt, prediction = values["str_y"], values["str_x"]
         for metric_name in metric_names:
-            if metric_name == "cer":
-                metrics["edit_chars"] = [
-                    edit_cer_from_string(u, v, self.layout_tokens)
-                    for u, v in zip(values["str_y"], values["str_x"])
-                ]
-                metrics["nb_chars"] = [
-                    nb_chars_cer_from_string(gt, self.layout_tokens)
-                    for gt in values["str_y"]
-                ]
-            elif metric_name == "wer":
-                split_gt = [
-                    format_string_for_wer(gt, self.layout_tokens)
-                    for gt in values["str_y"]
-                ]
-                split_pred = [
-                    format_string_for_wer(pred, self.layout_tokens)
-                    for pred in values["str_x"]
-                ]
-                metrics["edit_words"] = [
-                    edit_wer_from_formatted_split_text(gt, pred)
-                    for (gt, pred) in zip(split_gt, split_pred)
-                ]
-                metrics["nb_words"] = [len(gt) for gt in split_gt]
-            elif metric_name == "wer_no_punct":
-                split_gt = [
-                    format_string_for_wer(gt, self.layout_tokens, remove_punct=True)
-                    for gt in values["str_y"]
-                ]
-                split_pred = [
-                    format_string_for_wer(pred, self.layout_tokens, remove_punct=True)
-                    for pred in values["str_x"]
-                ]
-                metrics["edit_words_no_punct"] = [
-                    edit_wer_from_formatted_split_text(gt, pred)
-                    for (gt, pred) in zip(split_gt, split_pred)
-                ]
-                metrics["nb_words_no_punct"] = [len(gt) for gt in split_gt]
-            elif metric_name in [
-                "loss_ce",
-                "loss",
-            ]:
-                metrics[metric_name] = [
-                    values[metric_name],
-                ]
+            match metric_name:
+                case "cer":
+                    metrics["edit_chars"] = list(
+                        map(self.edit_cer_from_string, gt, prediction)
+                    )
+                    metrics["nb_chars"] = list(map(self.nb_chars_cer_from_string, gt))
+                case "wer" | "wer_no_punct":
+                    suffix = metric_name[3:]
+                    # One remove_punct flag per sample, so map() covers the whole batch
+                    flags = [bool(suffix)] * len(gt)
+                    split_gt = list(map(self.format_string_for_wer, gt, flags))
+                    split_pred = list(
+                        map(self.format_string_for_wer, prediction, flags)
+                    )
+                    metrics["edit_words" + suffix] = list(
+                        map(editdistance.eval, split_gt, split_pred)
+                    )
+                    metrics["nb_words" + suffix] = list(map(len, split_gt))
+                case "loss" | "loss_ce":
+                    metrics[metric_name] = [
+                        values[metric_name],
+                    ]
         return metrics
 
-    def get(self, name):
+    def get(self, name: str):
         return self.epoch_metrics[name]
-
-
-def keep_all_but_ner_tokens(str, tokens):
-    """
-    Remove all ner tokens from string
-    """
-    return re.sub("([" + tokens + "])", "", str)
-
-
-def edit_cer_from_string(gt, pred, layout_tokens=None):
-    """
-    Format and compute edit distance between two strings at character level
-    """
-    gt = format_string_for_cer(gt, layout_tokens)
-    pred = format_string_for_cer(pred, layout_tokens)
-    return editdistance.eval(gt, pred)
-
-
-def nb_chars_cer_from_string(gt, layout_tokens=None):
-    """
-    Compute length after formatting of ground truth string
-    """
-    return len(format_string_for_cer(gt, layout_tokens))
-
-
-def format_string_for_wer(str, layout_tokens, remove_punct=False):
-    """
-    Format string for WER computation: remove layout tokens, treat punctuation as word, replace line break by space
-    """
-    if remove_punct:
-        str = re.sub(
-            r"([\[\]{}/\\()\"'&+*=<>?.;:,!\-—_€#%°])", "", str
-        )  # remove punctuation
-    if layout_tokens is not None:
-        str = keep_all_but_ner_tokens(
-            str, layout_tokens
-        )  # remove layout tokens from metric
-    str = re.sub("([ \n])+", " ", str).strip()  # keep only one space character
-    return str.split(" ")
-
-
-def format_string_for_cer(str, layout_tokens):
-    """
-    Format string for CER computation: remove layout tokens and extra spaces
-    """
-    if layout_tokens is not None:
-        str = keep_all_but_ner_tokens(
-            str, layout_tokens
-        )  # remove layout tokens from metric
-    str = re.sub("([\n])+", "\n", str)  # remove consecutive line breaks
-    str = re.sub("([ ])+", " ", str).strip()  # remove consecutive spaces
-    return str
-
-
-def edit_wer_from_formatted_split_text(gt, pred):
-    """
-    Compute edit distance at word level from formatted string as list
-    """
-    return editdistance.eval(gt, pred)
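
For reviewers, a minimal sketch of how the refactored MetricManager can be exercised end to end. Only the module path, constructor signature, and method names come from the diff above; the metric list, dataset name, and batch values are hypothetical placeholders.

from dan.ocr.manager.metrics import MetricManager

# No transcription tokens file: layout-token stripping is skipped (tokens=None).
manager = MetricManager(metric_names=["cer", "wer"], dataset_name="train", tokens=None)

# One toy batch: "str_y" holds ground-truth strings, "str_x" the predictions.
batch_values = {
    "nb_samples": 2,
    "str_y": ["the quick brown fox", "jumps over the dog"],
    "str_x": ["the quick brwn fox", "jumps over the dog"],
}

# Per-sample metrics for the batch, accumulated into the epoch totals.
batch_metrics = manager.compute_metrics(batch_values, manager.metric_names)
manager.update_metrics(batch_metrics)

# Epoch-level values rounded to 4 decimals, e.g. {"cer": ..., "wer": ...}.
print(manager.get_display_values())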