From b582b4bb4f4508c5e6d167e0dc2766179d604f4f Mon Sep 17 00:00:00 2001
From: manonBlanco <blanco@teklia.com>
Date: Fri, 22 Dec 2023 14:33:40 +0100
Subject: [PATCH] Display 5 worst predictions at the end of evaluation

---
 dan/ocr/evaluate.py                  | 39 ++++++++++++++++++++++++++++
 dan/ocr/manager/metrics.py           |  2 ++
 dan/ocr/manager/training.py          | 13 ++++++++--
 tests/data/evaluate/metrics_table.md | 17 ++++++++++++
 4 files changed, 69 insertions(+), 2 deletions(-)

diff --git a/dan/ocr/evaluate.py b/dan/ocr/evaluate.py
index 079b0e5d..e101832a 100644
--- a/dan/ocr/evaluate.py
+++ b/dan/ocr/evaluate.py
@@ -6,12 +6,16 @@ Evaluate a trained DAN model.
 import logging
 import random
 from argparse import ArgumentTypeError
+from itertools import chain
+from operator import attrgetter
 from pathlib import Path
 from typing import Dict, List
 
 import numpy as np
 import torch
 import torch.multiprocessing as mp
+from edlib import align, getNiceAlignment
+from prettytable import MARKDOWN, PrettyTable
 
 from dan.bio import convert
 from dan.ocr.manager.metrics import Inference
@@ -25,6 +29,7 @@ from nerval.utils import print_results
 logger = logging.getLogger(__name__)
 
 NERVAL_THRESHOLD = 0.30
+NB_WORST_PREDICTIONS = 5
 
 
 def parse_threshold(value: str) -> float:
@@ -66,6 +71,38 @@ def add_evaluate_parser(subcommands) -> None:
     parser.set_defaults(func=run)
 
 
+def print_worst_predictions(all_inferences: Dict[str, List[Inference]]):
+    table = PrettyTable(
+        field_names=[
+            "Image name",
+            "WER",
+            "Alignment between ground truth - prediction",
+        ]
+    )
+    table.set_style(MARKDOWN)
+
+    worst_inferences = sorted(
+        chain.from_iterable(all_inferences.values()),
+        key=attrgetter("wer"),
+        reverse=True,
+    )[:NB_WORST_PREDICTIONS]
+    for inference in worst_inferences:
+        alignment = getNiceAlignment(
+            align(
+                inference.ground_truth,
+                inference.prediction,
+                task="path",
+            ),
+            inference.ground_truth,
+            inference.prediction,
+        )
+        alignment_str = f'{alignment["query_aligned"]}\n{alignment["matched_aligned"]}\n{alignment["target_aligned"]}'
+        table.add_row([inference.image, inference.wer, alignment_str])
+
+    print(f"\n#### {NB_WORST_PREDICTIONS} worst prediction(s)\n")
+    print(table)
+
+
 def eval_nerval(
     all_inferences: Dict[str, List[Inference]],
     tokens: Path,
@@ -159,6 +196,8 @@ def eval(rank, config: dict, nerval_threshold: float, mlflow_logging: bool):
             threshold=nerval_threshold,
         )
 
+    print_worst_predictions(all_inferences)
+
 
 def run(config: dict, nerval_threshold: float):
     update_config(config)
diff --git a/dan/ocr/manager/metrics.py b/dan/ocr/manager/metrics.py
index 988c90d0..07cece06 100644
--- a/dan/ocr/manager/metrics.py
+++ b/dan/ocr/manager/metrics.py
@@ -29,8 +29,10 @@ class Inference(NamedTuple):
     inferring again when we need to compute new metrics
     """
 
+    image: str
     ground_truth: str
     prediction: str
+    wer: float
 
 
 class MetricManager:
diff --git a/dan/ocr/manager/training.py b/dan/ocr/manager/training.py
index ad35f823..d436ead1 100644
--- a/dan/ocr/manager/training.py
+++ b/dan/ocr/manager/training.py
@@ -4,6 +4,7 @@ import os
 import random
 from copy import deepcopy
 from enum import Enum
+from itertools import repeat
 from pathlib import Path
 from time import time
 from typing import Dict, List, Tuple
@@ -768,7 +769,9 @@ class GenericTrainingManager:
             tokens=self.tokens,
         )
 
-        # Keep inferences in memory to evaluate with Nerval
+        # Keep inferences in memory to:
+        # - evaluate with Nerval
+        # - display worst predictions
         inferences = []
 
         with tqdm(total=len(loader.dataset)) as pbar:
@@ -796,7 +799,13 @@
                     pbar.update(len(batch_data["names"]) * self.nb_workers)
 
                     inferences.extend(
-                        map(Inference, batch_values["str_y"], batch_values["str_x"])
+                        map(
+                            Inference,
+                            batch_data["names"],
+                            batch_values["str_y"],
+                            batch_values["str_x"],
+                            repeat(display_values["wer"]),
+                        )
                     )
 
             # log metrics in MLflow
diff --git a/tests/data/evaluate/metrics_table.md b/tests/data/evaluate/metrics_table.md
index 107bef41..8e562331 100644
--- a/tests/data/evaluate/metrics_table.md
+++ b/tests/data/evaluate/metrics_table.md
@@ -47,3 +47,20 @@
 | Chalumeau | 1 | 0 | 0.0 | 0.0 | 0 | 1 |
 | Batiment | 1 | 1 | 100.0 | 100.0 | 100.0 | 1 |
 | All | 6 | 5 | 83.33 | 83.33 | 83.33 | 6 |
+
+#### 5 worst prediction(s)
+
+|                Image name                |  WER   |        Alignment between ground truth - prediction        |
+|:----------------------------------------:|:------:|:---------------------------------------------------------:|
+| 2c242f5c-e979-43c4-b6f2-a6d4815b651d.png |  0.5   |             ⓈA ⒻCharles Ⓑ11 ⓁP ⒸC ⓀF ⓄA Ⓟ14331             |
+|                                          |        |             |.||||||||||||||||||||||||.||||.||             |
+|                                          |        |             Ⓢd ⒻCharles Ⓑ11 ⓁP ⒸC ⓀF Ⓞd Ⓟ14 31             |
+| 0dfe8bcd-ed0b-453e-bf19-cc697012296e.png | 0.2667 |      ⓈTemplié ⒻMarcelle Ⓑ93 ⓁJ Ⓚch ⓄE dachyle-------       |
+|                                          |        |      ||||||||||||||||||||||||.|||||||||||.||.-------       |
+|                                          |        |      ⓈTemplié ⒻMarcelle Ⓑ93 ⓁS Ⓚch ⓄE dactylo Ⓟ18376       |
+| ffdec445-7f14-4f5f-be44-68d0844d0df1.png | 0.1429 |            ⓈNaudin ⒻMarie Ⓑ53 ⓁS ⒸV ⓀBelle mère            |
+|                                          |        |            |||||||||||||||||||||||.||||||||||||            |
+|                                          |        |            ⓈNaudin ⒻMarie Ⓑ53 ⓁS Ⓒv ⓀBelle mère            |
+| 0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84.png | 0.125  | ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier Ⓟ-------12241  |
+|                                          |        | |||||||||||||||||||||||||||||||||||||||||||||-------|||||  |
+|                                          |        | ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241  |
-- 
GitLab
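
A minimal usage sketch of the helper added above, assuming the patch is applied and the `dan` package is importable. The split name, image names, transcriptions, and WER values below are made up for illustration; `print_worst_predictions` and `Inference` are the objects patched in dan/ocr/evaluate.py and dan/ocr/manager/metrics.py.

# Illustrative only: feed hand-written Inference tuples to the new helper.
# In real runs these are collected by GenericTrainingManager during evaluation.
from dan.ocr.evaluate import print_worst_predictions
from dan.ocr.manager.metrics import Inference

# Hypothetical split name, image names, strings and WER values.
all_inferences = {
    "val": [
        Inference("page_0.png", "ⓈNaudin ⒻMarie Ⓑ53", "ⓈNaudin ⒻMarie Ⓑ53", 0.0),
        Inference("page_1.png", "ⓈBellisson ⒻGeorges", "ⓈBelisson ⒻGeorge", 0.4),
    ],
}

# Prints a "#### 5 worst prediction(s)" header, then a Markdown table showing,
# for each image (highest WER first), the ground truth, the edlib match line,
# and the prediction.
print_worst_predictions(all_inferences)

Note that the header always reads "5 worst prediction(s)" because it is rendered from the NB_WORST_PREDICTIONS constant; the table itself holds at most that many entries, fewer when the evaluated set is smaller, as in this sketch.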