diff --git a/dan/ocr/manager/metrics.py b/dan/ocr/manager/metrics.py
index d941c21436ab557026bc2b2510f1450ee0139dbd..602e75ac8b354bf1be321c200af6e6c597136d79 100644
--- a/dan/ocr/manager/metrics.py
+++ b/dan/ocr/manager/metrics.py
@@ -1,123 +1,133 @@
 # -*- coding: utf-8 -*-
 import re
+from collections import defaultdict
 from operator import attrgetter
 from pathlib import Path
-from typing import Optional
+from typing import Dict, List, Optional
 
 import editdistance
 import numpy as np
 
 from dan.utils import parse_tokens
 
+# Remove punctuation
+REGEX_PUNCTUATION = re.compile(r"([\[\]{}/\\()\"'&+*=<>?.;:,!\-—_€#%°])")
+# Remove consecutive linebreaks
+REGEX_CONSECUTIVE_LINEBREAKS = re.compile(r"\n+")
+# Remove consecutive spaces
+REGEX_CONSECUTIVE_SPACES = re.compile(r" +")
+# Keep only one space character
+REGEX_ONLY_ONE_SPACE = re.compile(r"\s+")
+
 
 class MetricManager:
-    def __init__(self, metric_names, dataset_name, tokens: Optional[Path]):
-        self.dataset_name = dataset_name
+    def __init__(
+        self, metric_names: List[str], dataset_name: str, tokens: Optional[Path]
+    ):
+        self.dataset_name: str = dataset_name
+        self.remove_tokens: str = None
 
-        self.layout_tokens = None
         if tokens:
             tokens = parse_tokens(tokens)
-            self.layout_tokens = "".join(
+            layout_tokens = "".join(
                 list(map(attrgetter("start"), tokens.values()))
                 + list(map(attrgetter("end"), tokens.values()))
             )
-        self.metric_names = metric_names
-        self.epoch_metrics = None
+            self.remove_tokens: re.Pattern = re.compile(r"([" + layout_tokens + "])")
+        self.metric_names: List[str] = metric_names
+        self.epoch_metrics = defaultdict(list)
 
-        self.linked_metrics = {
-            "cer": ["edit_chars", "nb_chars"],
-            "wer": ["edit_words", "nb_words"],
-            "wer_no_punct": ["edit_words_no_punct", "nb_words_no_punct"],
-        }
+    def edit_cer_from_string(self, gt: str, pred: str):
+        """
+        Format and compute edit distance between two strings at character level
+        """
+        gt = self.format_string_for_cer(gt)
+        pred = self.format_string_for_cer(pred)
+        return editdistance.eval(gt, pred)
+
+    def nb_chars_cer_from_string(self, gt: str) -> int:
+        """
+        Compute length after formatting of ground truth string
+        """
+        return len(self.format_string_for_cer(gt))
 
-        self.init_metrics()
+    def format_string_for_wer(self, text: str, remove_punct: bool = False):
+        """
+        Format string for WER computation: remove layout tokens, treat punctuation as word, replace line break by space
+        """
+        if remove_punct:
+            text = REGEX_PUNCTUATION.sub("", text)
+        if self.remove_tokens is not None:
+            text = self.remove_tokens.sub("", text)
+        return REGEX_ONLY_ONE_SPACE.sub(" ", text).strip().split(" ")
 
-    def init_metrics(self):
+    def format_string_for_cer(self, text: str):
         """
-        Initialization of the metrics specified in metrics_name
+        Format string for CER computation: remove layout tokens and extra spaces
         """
-        self.epoch_metrics = {
-            "nb_samples": list(),
-            "names": list(),
-        }
+        if self.remove_tokens is not None:
+            text = self.remove_tokens.sub("", text)
 
-        for metric_name in self.metric_names:
-            if metric_name in self.linked_metrics:
-                for linked_metric_name in self.linked_metrics[metric_name]:
-                    if linked_metric_name not in self.epoch_metrics:
-                        self.epoch_metrics[linked_metric_name] = list()
-            else:
-                self.epoch_metrics[metric_name] = list()
+        text = REGEX_CONSECUTIVE_LINEBREAKS.sub("\n", text)
+        return REGEX_CONSECUTIVE_SPACES.sub(" ", text).strip()
 
     def update_metrics(self, batch_metrics):
         """
         Add batch metrics to the metrics
         """
         for key in batch_metrics:
-            if key in self.epoch_metrics:
-                self.epoch_metrics[key] += batch_metrics[key]
+            self.epoch_metrics[key] += batch_metrics[key]
 
-    def get_display_values(self, output=False):
+    def get_display_values(self, output: bool = False):
         """
-        format metrics values for shell display purposes
+        Format metrics values for shell display purposes
         """
         metric_names = self.metric_names.copy()
         if output:
-            metric_names.extend(["nb_samples"])
+            metric_names.append("nb_samples")
         display_values = dict()
         for metric_name in metric_names:
-            value = None
-            if output:
-                if metric_name == "nb_samples":
-                    value = int(np.sum(self.epoch_metrics[metric_name]))
-                elif metric_name == "time":
+            match metric_name:
+                case "time" | "nb_samples":
+                    if not output:
+                        continue
                     value = int(np.sum(self.epoch_metrics[metric_name]))
-                    sample_time = value / np.sum(self.epoch_metrics["nb_samples"])
-                    display_values["sample_time"] = float(round(sample_time, 4))
-            if metric_name == "cer":
-                value = float(
-                    np.sum(self.epoch_metrics["edit_chars"])
-                    / np.sum(self.epoch_metrics["nb_chars"])
-                )
-                if output:
-                    display_values["nb_chars"] = int(
-                        np.sum(self.epoch_metrics["nb_chars"])
+                    if metric_name == "time":
+                        sample_time = value / np.sum(self.epoch_metrics["nb_samples"])
+                        display_values["sample_time"] = float(round(sample_time, 4))
+                    display_values[metric_name] = value
+                    continue
+                case "cer":
+                    num_name, denom_name = "edit_chars", "nb_chars"
+                case "wer" | "wer_no_punct":
+                    suffix = metric_name[3:]
+                    num_name, denom_name = "edit_words" + suffix, "nb_words" + suffix
+                case "loss" | "loss_ce":
+                    display_values[metric_name] = round(
+                        float(
+                            np.average(
+                                self.epoch_metrics[metric_name],
+                                weights=np.array(self.epoch_metrics["nb_samples"]),
+                            ),
+                        ),
+                        4,
                     )
-            elif metric_name == "wer":
-                value = float(
-                    np.sum(self.epoch_metrics["edit_words"])
-                    / np.sum(self.epoch_metrics["nb_words"])
-                )
-                if output:
-                    display_values["nb_words"] = int(
-                        np.sum(self.epoch_metrics["nb_words"])
-                    )
-            elif metric_name == "wer_no_punct":
-                value = float(
-                    np.sum(self.epoch_metrics["edit_words_no_punct"])
-                    / np.sum(self.epoch_metrics["nb_words_no_punct"])
-                )
-                if output:
-                    display_values["nb_words_no_punct"] = int(
-                        np.sum(self.epoch_metrics["nb_words_no_punct"])
-                    )
-            elif metric_name in [
-                "loss",
-                "loss_ce",
-            ]:
-                value = float(
-                    np.average(
-                        self.epoch_metrics[metric_name],
-                        weights=np.array(self.epoch_metrics["nb_samples"]),
-                    )
-                )
-            elif value is None:
-                continue
+                    continue
+                case _:
+                    continue
 
+            value = float(
+                np.sum(self.epoch_metrics[num_name])
+                / np.sum(self.epoch_metrics[denom_name])
+            )
+            if output:
+                display_values[denom_name] = int(np.sum(self.epoch_metrics[denom_name]))
             display_values[metric_name] = round(value, 4)
         return display_values
 
-    def compute_metrics(self, values, metric_names):
+    def compute_metrics(
+        self, values: Dict[str, int | float], metric_names: List[str]
+    ) -> Dict[str, List[int | float]]:
         metrics = {
             "nb_samples": [
                 values["nb_samples"],
@@ -125,111 +135,30 @@ class MetricManager:
         }
         if "time" in values:
             metrics["time"] = [values["time"]]
+
+        gt, prediction = values["str_y"], values["str_x"]
         for metric_name in metric_names:
-            if metric_name == "cer":
-                metrics["edit_chars"] = [
-                    edit_cer_from_string(u, v, self.layout_tokens)
-                    for u, v in zip(values["str_y"], values["str_x"])
-                ]
-                metrics["nb_chars"] = [
-                    nb_chars_cer_from_string(gt, self.layout_tokens)
-                    for gt in values["str_y"]
-                ]
-            elif metric_name == "wer":
-                split_gt = [
-                    format_string_for_wer(gt, self.layout_tokens)
-                    for gt in values["str_y"]
-                ]
-                split_pred = [
-                    format_string_for_wer(pred, self.layout_tokens)
-                    for pred in values["str_x"]
-                ]
-                metrics["edit_words"] = [
-                    edit_wer_from_formatted_split_text(gt, pred)
-                    for (gt, pred) in zip(split_gt, split_pred)
-                ]
-                metrics["nb_words"] = [len(gt) for gt in split_gt]
-            elif metric_name == "wer_no_punct":
-                split_gt = [
-                    format_string_for_wer(gt, self.layout_tokens, remove_punct=True)
-                    for gt in values["str_y"]
-                ]
-                split_pred = [
-                    format_string_for_wer(pred, self.layout_tokens, remove_punct=True)
-                    for pred in values["str_x"]
-                ]
-                metrics["edit_words_no_punct"] = [
-                    edit_wer_from_formatted_split_text(gt, pred)
-                    for (gt, pred) in zip(split_gt, split_pred)
-                ]
-                metrics["nb_words_no_punct"] = [len(gt) for gt in split_gt]
-            elif metric_name in [
-                "loss_ce",
-                "loss",
-            ]:
-                metrics[metric_name] = [
-                    values[metric_name],
-                ]
+            match metric_name:
+                case "cer":
+                    metrics["edit_chars"] = list(
+                        map(self.edit_cer_from_string, gt, prediction)
+                    )
+                    metrics["nb_chars"] = list(map(self.nb_chars_cer_from_string, gt))
+                case "wer" | "wer_no_punct":
+                    suffix = metric_name[3:]
+                    split_gt = list(map(self.format_string_for_wer, gt, [bool(suffix)]))
+                    split_pred = list(
+                        map(self.format_string_for_wer, prediction, [bool(suffix)])
+                    )
+                    metrics["edit_words" + suffix] = list(
+                        map(editdistance.eval, split_gt, split_pred)
+                    )
+                    metrics["nb_words" + suffix] = list(map(len, split_gt))
+                case "loss" | "loss_ce":
+                    metrics[metric_name] = [
+                        values[metric_name],
+                    ]
         return metrics
 
-    def get(self, name):
+    def get(self, name: str):
         return self.epoch_metrics[name]
-
-
-def keep_all_but_ner_tokens(str, tokens):
-    """
-    Remove all ner tokens from string
-    """
-    return re.sub("([" + tokens + "])", "", str)
-
-
-def edit_cer_from_string(gt, pred, layout_tokens=None):
-    """
-    Format and compute edit distance between two strings at character level
-    """
-    gt = format_string_for_cer(gt, layout_tokens)
-    pred = format_string_for_cer(pred, layout_tokens)
-    return editdistance.eval(gt, pred)
-
-
-def nb_chars_cer_from_string(gt, layout_tokens=None):
-    """
-    Compute length after formatting of ground truth string
-    """
-    return len(format_string_for_cer(gt, layout_tokens))
-
-
-def format_string_for_wer(str, layout_tokens, remove_punct=False):
-    """
-    Format string for WER computation: remove layout tokens, treat punctuation as word, replace line break by space
-    """
-    if remove_punct:
-        str = re.sub(
-            r"([\[\]{}/\\()\"'&+*=<>?.;:,!\-—_€#%°])", "", str
-        )  # remove punctuation
-    if layout_tokens is not None:
-        str = keep_all_but_ner_tokens(
-            str, layout_tokens
-        )  # remove layout tokens from metric
-    str = re.sub("([ \n])+", " ", str).strip()  # keep only one space character
-    return str.split(" ")
-
-
-def format_string_for_cer(str, layout_tokens):
-    """
-    Format string for CER computation: remove layout tokens and extra spaces
-    """
-    if layout_tokens is not None:
-        str = keep_all_but_ner_tokens(
-            str, layout_tokens
-        )  # remove layout tokens from metric
-    str = re.sub("([\n])+", "\n", str)  # remove consecutive line breaks
-    str = re.sub("([ ])+", " ", str).strip()  # remove consecutive spaces
-    return str
-
-
-def edit_wer_from_formatted_split_text(gt, pred):
-    """
-    Compute edit distance at word level from formatted string as list
-    """
-    return editdistance.eval(gt, pred)