Skip to content
Snippets Groups Projects

Display 5 worst predictions at the end of evaluation

Merged Manon Blanco requested to merge display-worst-predictions into main
All threads resolved!
4 files
+ 62
32
Compare changes
  • Side-by-side
  • Inline
Files
4
+ 39
0
@@ -6,12 +6,16 @@ Evaluate a trained DAN model.
import logging
import random
from argparse import ArgumentTypeError
from itertools import chain
from operator import attrgetter
from pathlib import Path
from typing import Dict, List
import numpy as np
import torch
import torch.multiprocessing as mp
from edlib import align, getNiceAlignment
from prettytable import MARKDOWN, PrettyTable
from dan.bio import convert
from dan.ocr.manager.metrics import Inference
@@ -25,6 +29,7 @@ from nerval.utils import print_results
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# NOTE(review): presumably the default threshold handed to Nerval when none is
# given on the command line — confirm against the argument parser.
NERVAL_THRESHOLD = 0.30
# Maximum number of highest-WER predictions listed by print_worst_predictions.
NB_WORST_PREDICTIONS = 5
def parse_threshold(value: str) -> float:
@@ -66,6 +71,38 @@ def add_evaluate_parser(subcommands) -> None:
parser.set_defaults(func=run)
def print_worst_predictions(all_inferences: Dict[str, List[Inference]]):
    """
    Print a Markdown table of the predictions with the highest WER.

    The inferences stored under every key of ``all_inferences`` are pooled,
    ranked by decreasing ``wer`` and truncated to ``NB_WORST_PREDICTIONS``
    entries. Each kept inference yields one row holding its image, its WER
    multiplied by 100 and rounded to two decimals, and a three-line edlib
    alignment of the ground truth against the prediction.

    :param all_inferences: Inferences to rank, grouped under string keys.
    """
    # Pool every group into one list, then rank from highest to lowest WER.
    ranked = list(chain.from_iterable(all_inferences.values()))
    ranked.sort(key=attrgetter("wer"), reverse=True)

    columns = [
        "Image name",
        "WER",
        "Alignment between ground truth - prediction",
    ]
    table = PrettyTable(field_names=columns)
    table.set_style(MARKDOWN)

    for item in ranked[:NB_WORST_PREDICTIONS]:
        truth = item.ground_truth
        predicted = item.prediction
        # The ground truth is the edlib query and the prediction is the target.
        raw_alignment = align(truth, predicted, task="path")
        nice = getNiceAlignment(raw_alignment, truth, predicted)
        # Stack the aligned query, the match markers and the aligned target.
        stacked = "\n".join(
            (
                nice["query_aligned"],
                nice["matched_aligned"],
                nice["target_aligned"],
            )
        )
        wer_percent = round(item.wer * 100, 2)
        table.add_row([item.image, wer_percent, stacked])

    print(f"\n#### {NB_WORST_PREDICTIONS} worst prediction(s)\n")
    print(table)
def eval_nerval(
all_inferences: Dict[str, List[Inference]],
tokens: Path,
@@ -159,6 +196,8 @@ def eval(rank, config: dict, nerval_threshold: float, mlflow_logging: bool):
threshold=nerval_threshold,
)
print_worst_predictions(all_inferences)
def run(config: dict, nerval_threshold: float):
update_config(config)
Loading