From 0930e59dc566484388f0cc7cbf6f2a4ed13b956d Mon Sep 17 00:00:00 2001
From: manonBlanco <blanco@teklia.com>
Date: Fri, 22 Dec 2023 14:33:40 +0100
Subject: [PATCH] Display 5 worst predictions at the end of evaluation

---
 dan/ocr/evaluate.py                  | 39 ++++++++++++++++++++++++++++
 dan/ocr/manager/metrics.py           |  2 ++
 dan/ocr/manager/training.py          | 14 +++++++---
 tests/data/evaluate/metrics_table.md | 17 ++++++++++++
 4 files changed, 69 insertions(+), 3 deletions(-)

diff --git a/dan/ocr/evaluate.py b/dan/ocr/evaluate.py
index 041d309e..0ee69128 100644
--- a/dan/ocr/evaluate.py
+++ b/dan/ocr/evaluate.py
@@ -6,12 +6,16 @@ Evaluate a trained DAN model.
 import logging
 import random
 from argparse import ArgumentTypeError
+from itertools import chain
+from operator import attrgetter
 from pathlib import Path
 from typing import Dict, List
 
 import numpy as np
 import torch
 import torch.multiprocessing as mp
+from edlib import align, getNiceAlignment
+from prettytable import MARKDOWN, PrettyTable
 
 from dan.bio import convert
 from dan.ocr.manager.metrics import Inference
@@ -25,6 +29,7 @@ from nerval.utils import print_results
 logger = logging.getLogger(__name__)
 
 NERVAL_THRESHOLD = 0.30
+NB_WORST_PREDICTIONS = 5
 
 
 def parse_threshold(arg):
@@ -63,6 +68,38 @@ def add_evaluate_parser(subcommands) -> None:
     parser.set_defaults(func=run)
 
 
+def print_worst_predictions(all_inferences: Dict[str, List[Inference]]):
+    table = PrettyTable(
+        field_names=[
+            "Image name",
+            "WER",
+            "Alignment between ground truth - prediction",
+        ]
+    )
+    table.set_style(MARKDOWN)
+
+    worst_inferences = sorted(
+        chain.from_iterable(all_inferences.values()),
+        key=attrgetter("wer"),
+        reverse=True,
+    )[:NB_WORST_PREDICTIONS]
+    for inference in worst_inferences:
+        alignment = getNiceAlignment(
+            align(
+                inference.ground_truth,
+                inference.prediction,
+                task="path",
+            ),
+            inference.ground_truth,
+            inference.prediction,
+        )
+        alignment_str = f'{alignment["query_aligned"]}\n{alignment["matched_aligned"]}\n{alignment["target_aligned"]}'
+        table.add_row([inference.image, inference.wer, alignment_str])
+
+    print(f"\n#### {NB_WORST_PREDICTIONS} worst prediction(s)\n")
+    print(table)
+
+
 def eval_nerval(
     all_inferences: Dict[str, List[Inference]],
     tokens: Path,
@@ -149,6 +186,8 @@ def eval(rank, config: dict, nerval_threshold: float, mlflow_logging: bool):
             threshold=nerval_threshold,
         )
 
+    print_worst_predictions(all_inferences)
+
 
 def run(config: dict, nerval_threshold: float):
     update_config(config)
diff --git a/dan/ocr/manager/metrics.py b/dan/ocr/manager/metrics.py
index d7b3f6f2..fc088295 100644
--- a/dan/ocr/manager/metrics.py
+++ b/dan/ocr/manager/metrics.py
@@ -26,8 +26,10 @@ METRICS_KEYWORD = {"cer": "chars", "wer": "words", "ner": "tokens"}
 
 @dataclass
 class Inference:
+    image: str
     ground_truth: str
     prediction: str
+    wer: float
 
 
 class MetricManager:
diff --git a/dan/ocr/manager/training.py b/dan/ocr/manager/training.py
index 7416768c..5c7bfb91 100644
--- a/dan/ocr/manager/training.py
+++ b/dan/ocr/manager/training.py
@@ -4,7 +4,7 @@ import os
 import random
 from copy import deepcopy
 from enum import Enum
-from itertools import starmap
+from itertools import repeat, starmap
 from pathlib import Path
 from time import time
 from typing import Dict, List, Tuple
@@ -769,7 +769,9 @@ class GenericTrainingManager:
             tokens=self.tokens,
         )
 
-        # Keep inferences in memory to evaluate with Nerval
+        # Keep inferences in memory to:
+        # - evaluate with Nerval
+        # - display worst predictions
         inferences = []
 
         with tqdm(total=len(loader.dataset)) as pbar:
@@ -798,7 +800,13 @@ class GenericTrainingManager:
 
                 inferences.extend(
                     starmap(
-                        Inference, zip(batch_values["str_y"], batch_values["str_x"])
+                        Inference,
+                        zip(
+                            batch_data["names"],
+                            batch_values["str_y"],
+                            batch_values["str_x"],
+                            repeat(display_values["wer"]),
+                        ),
                     )
                 )
 
diff --git a/tests/data/evaluate/metrics_table.md b/tests/data/evaluate/metrics_table.md
index 76d976c3..b48adcf4 100644
--- a/tests/data/evaluate/metrics_table.md
+++ b/tests/data/evaluate/metrics_table.md
@@ -47,3 +47,20 @@
 | Chalumeau | 1         | 0       | 0.0       | 0.0    | 0     | 1       |
 | Batiment  | 1         | 1       | 1.0       | 1.0    | 1.0   | 1       |
 | All       | 6         | 5       | 0.833     | 0.833  | 0.833 | 6       |
+
+#### 5 worst prediction(s)
+
+|                Image name                |  WER   |        Alignment between ground truth - prediction        |
+|:----------------------------------------:|:------:|:---------------------------------------------------------:|
+| 2c242f5c-e979-43c4-b6f2-a6d4815b651d.png |  0.5   |            ⓈA ⒻCharles Ⓑ11 ⓁP ⒸC ⓀF ⓄA Ⓟ14331             |
+|                                          |        |            |.||||||||||||||||||||||||.||||.||             |
+|                                          |        |            Ⓢd ⒻCharles Ⓑ11 ⓁP ⒸC ⓀF Ⓞd Ⓟ14 31             |
+| 0dfe8bcd-ed0b-453e-bf19-cc697012296e.png | 0.2667 |      ⓈTemplié ⒻMarcelle Ⓑ93 ⓁJ Ⓚch ⓄE dachyle-------      |
+|                                          |        |      ||||||||||||||||||||||||.|||||||||||.||.-------      |
+|                                          |        |      ⓈTemplié ⒻMarcelle Ⓑ93 ⓁS Ⓚch ⓄE dactylo Ⓟ18376      |
+| ffdec445-7f14-4f5f-be44-68d0844d0df1.png | 0.1429 |           ⓈNaudin ⒻMarie Ⓑ53 ⓁS ⒸV ⓀBelle mère            |
+|                                          |        |           |||||||||||||||||||||||.||||||||||||            |
+|                                          |        |           ⓈNaudin ⒻMarie Ⓑ53 ⓁS Ⓒv ⓀBelle mère            |
+| 0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84.png | 0.125  | ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier Ⓟ-------12241 |
+|                                          |        | |||||||||||||||||||||||||||||||||||||||||||||-------||||| |
+|                                          |        | ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241 |
-- 
GitLab
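
For reviewers who want to try the new report without running a full evaluation, here is a minimal, self-contained sketch of the same edlib + PrettyTable pattern this patch adds to `dan/ocr/evaluate.py`. The `Inference` records below are toy stand-ins, not real DAN outputs:

```python
# Standalone sketch of the worst-predictions report (toy data).
# Requires: pip install edlib prettytable
from dataclasses import dataclass
from itertools import chain
from operator import attrgetter

from edlib import align, getNiceAlignment
from prettytable import MARKDOWN, PrettyTable

NB_WORST_PREDICTIONS = 5


@dataclass
class Inference:
    image: str
    ground_truth: str
    prediction: str
    wer: float


def print_worst_predictions(all_inferences):
    table = PrettyTable(
        field_names=["Image name", "WER", "Alignment between ground truth - prediction"]
    )
    table.set_style(MARKDOWN)

    # Flatten the per-set lists, keep the N samples with the highest WER.
    worst_inferences = sorted(
        chain.from_iterable(all_inferences.values()),
        key=attrgetter("wer"),
        reverse=True,
    )[:NB_WORST_PREDICTIONS]

    for inference in worst_inferences:
        # task="path" makes edlib return a CIGAR path, which getNiceAlignment
        # expands into three printable rows: aligned ground truth, match
        # markers (| match, . mismatch, - gap), aligned prediction.
        alignment = getNiceAlignment(
            align(inference.ground_truth, inference.prediction, task="path"),
            inference.ground_truth,
            inference.prediction,
        )
        alignment_str = "\n".join(
            (
                alignment["query_aligned"],
                alignment["matched_aligned"],
                alignment["target_aligned"],
            )
        )
        table.add_row([inference.image, inference.wer, alignment_str])

    print(f"\n#### {NB_WORST_PREDICTIONS} worst prediction(s)\n")
    print(table)


if __name__ == "__main__":
    print_worst_predictions(
        {"test": [Inference("a.png", "ⓈNaudin ⒸV", "ⓈNaudin Ⓒv", 0.5)]}
    )
```

Because the alignment string contains newlines, PrettyTable renders each entry as three physical table rows with empty neighbouring cells, which is exactly the shape of the expected output in `tests/data/evaluate/metrics_table.md` above.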
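
On the `training.py` side, the only non-obvious piece is the `starmap`/`zip`/`repeat` combination: `batch_data["names"]`, `batch_values["str_y"]` and `batch_values["str_x"]` are per-sample lists, while `display_values["wer"]` appears to be a single batch-level value, so `repeat()` is what pairs it with every sample. A small sketch with hypothetical batch dicts (the keys mirror the patch; the values are made up):

```python
# Sketch of the Inference construction from the training.py hunk.
# batch_data / batch_values / display_values are hypothetical stand-ins.
from dataclasses import dataclass
from itertools import repeat, starmap


@dataclass
class Inference:
    image: str
    ground_truth: str
    prediction: str
    wer: float


batch_data = {"names": ["a.png", "b.png"]}
batch_values = {"str_y": ["gt a", "gt b"], "str_x": ["pred a", "pred b"]}
display_values = {"wer": 0.25}  # one value for the whole batch

inferences = list(
    starmap(
        Inference,
        zip(
            batch_data["names"],    # image name, per sample
            batch_values["str_y"],  # ground truth, per sample
            batch_values["str_x"],  # prediction, per sample
            repeat(display_values["wer"]),  # scalar broadcast to every sample
        ),
    )
)
# -> [Inference(image='a.png', ground_truth='gt a', prediction='pred a', wer=0.25), ...]
```

`zip` stops at the shortest iterable, so the unbounded `repeat` is safe here; each `starmap` call unpacks one `(name, ground_truth, prediction, wer)` tuple into the `Inference` dataclass.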