Skip to content
Snippets Groups Projects

Evaluate predictions with nerval

Merged Manon Blanco requested to merge nerval-evaluate into main
All threads resolved!
7 files
+ 94
9
Compare changes
  • Side-by-side
  • Inline
Files
7
+ 12
0
# -*- coding: utf-8 -*-
import re
from collections import defaultdict
from dataclasses import dataclass
from operator import attrgetter
from pathlib import Path
from typing import Dict, List
@@ -23,6 +24,12 @@ REGEX_ONLY_ONE_SPACE = re.compile(r"\s+")
# Maps each metric name to the unit it is computed over — presumably
# character error rate over chars, word error rate over words, and
# named-entity recognition over tokens. NOTE(review): confirm against
# the metric computation code, which is outside this hunk.
METRICS_KEYWORD = {"cer": "chars", "wer": "words", "ner": "tokens"}
@dataclass
class Inference:
    """A single model prediction paired with its ground truth.

    Instances are accumulated by ``MetricManager`` (see its
    ``self.inferences`` list) as metrics are computed, one per sample.
    """

    # Reference transcription(s) for the sample (taken from values["str_y"]).
    ground_truth: List[str]
    # Model output(s) for the sample (taken from values["str_x"]).
    prediction: List[str]
class MetricManager:
def __init__(self, metric_names: List[str], dataset_name: str, tokens: Path | None):
self.dataset_name: str = dataset_name
@@ -40,6 +47,9 @@ class MetricManager:
self.metric_names: List[str] = metric_names
self.epoch_metrics = defaultdict(list)
# List of inferences (prediction with their ground truth)
self.inferences = []
def format_string_for_cer(self, text: str, remove_token: bool = False):
"""
Format string for CER computation: remove layout tokens and extra spaces
@@ -155,6 +165,8 @@ class MetricManager:
metrics["time"] = [values["time"]]
gt, prediction = values["str_y"], values["str_x"]
self.inferences.append(Inference(ground_truth=gt, prediction=prediction))
for metric_name in metric_names:
match metric_name:
case (
Loading