From 0930e59dc566484388f0cc7cbf6f2a4ed13b956d Mon Sep 17 00:00:00 2001
From: manonBlanco <blanco@teklia.com>
Date: Fri, 22 Dec 2023 14:33:40 +0100
Subject: [PATCH] Display 5 worst predictions at the end of evaluation

---
 dan/ocr/evaluate.py                  | 39 ++++++++++++++++++++++++++++
 dan/ocr/manager/metrics.py           |  2 ++
 dan/ocr/manager/training.py          | 14 +++++++---
 tests/data/evaluate/metrics_table.md | 17 ++++++++++++
 4 files changed, 69 insertions(+), 3 deletions(-)

diff --git a/dan/ocr/evaluate.py b/dan/ocr/evaluate.py
index 041d309e..0ee69128 100644
--- a/dan/ocr/evaluate.py
+++ b/dan/ocr/evaluate.py
@@ -6,12 +6,16 @@ Evaluate a trained DAN model.
 import logging
 import random
 from argparse import ArgumentTypeError
+from itertools import chain
+from operator import attrgetter
 from pathlib import Path
 from typing import Dict, List
 
 import numpy as np
 import torch
 import torch.multiprocessing as mp
+from edlib import align, getNiceAlignment
+from prettytable import MARKDOWN, PrettyTable
 
 from dan.bio import convert
 from dan.ocr.manager.metrics import Inference
@@ -25,6 +29,7 @@ from nerval.utils import print_results
 logger = logging.getLogger(__name__)
 
 NERVAL_THRESHOLD = 0.30
+NB_WORST_PREDICTIONS = 5
 
 
 def parse_threshold(arg):
@@ -63,6 +68,38 @@ def add_evaluate_parser(subcommands) -> None:
     parser.set_defaults(func=run)
 
 
+def print_worst_predictions(all_inferences: Dict[str, List[Inference]]):
+    table = PrettyTable(
+        field_names=[
+            "Image name",
+            "WER",
+            "Alignment between ground truth - prediction",
+        ]
+    )
+    table.set_style(MARKDOWN)
+
+    worst_inferences = sorted(
+        chain.from_iterable(all_inferences.values()),
+        key=attrgetter("wer"),
+        reverse=True,
+    )[:NB_WORST_PREDICTIONS]
+    for inference in worst_inferences:
+        alignment = getNiceAlignment(
+            align(
+                inference.ground_truth,
+                inference.prediction,
+                task="path",
+            ),
+            inference.ground_truth,
+            inference.prediction,
+        )
+        alignment_str = f'{alignment["query_aligned"]}\n{alignment["matched_aligned"]}\n{alignment["target_aligned"]}'
+        table.add_row([inference.image, inference.wer, alignment_str])
+
+    print(f"\n#### {NB_WORST_PREDICTIONS} worst prediction(s)\n")
+    print(table)
+
+
 def eval_nerval(
     all_inferences: Dict[str, List[Inference]],
     tokens: Path,
@@ -149,6 +186,8 @@ def eval(rank, config: dict, nerval_threshold: float, mlflow_logging: bool):
             threshold=nerval_threshold,
         )
 
+    print_worst_predictions(all_inferences)
+
 
 def run(config: dict, nerval_threshold: float):
     update_config(config)
diff --git a/dan/ocr/manager/metrics.py b/dan/ocr/manager/metrics.py
index d7b3f6f2..fc088295 100644
--- a/dan/ocr/manager/metrics.py
+++ b/dan/ocr/manager/metrics.py
@@ -26,8 +26,10 @@ METRICS_KEYWORD = {"cer": "chars", "wer": "words", "ner": "tokens"}
 
 @dataclass
 class Inference:
+    image: str
     ground_truth: str
     prediction: str
+    wer: float
 
 
 class MetricManager:
diff --git a/dan/ocr/manager/training.py b/dan/ocr/manager/training.py
index 7416768c..5c7bfb91 100644
--- a/dan/ocr/manager/training.py
+++ b/dan/ocr/manager/training.py
@@ -4,7 +4,7 @@ import os
 import random
 from copy import deepcopy
 from enum import Enum
-from itertools import starmap
+from itertools import repeat, starmap
 from pathlib import Path
 from time import time
 from typing import Dict, List, Tuple
@@ -769,7 +769,9 @@ class GenericTrainingManager:
             tokens=self.tokens,
         )
 
-        # Keep inferences in memory to evaluate with Nerval
+        # Keep inferences in memory to:
+        # - evaluate with Nerval
+        # - display worst predictions
         inferences = []
 
         with tqdm(total=len(loader.dataset)) as pbar:
@@ -798,7 +800,13 @@ class GenericTrainingManager:
 
                     inferences.extend(
                         starmap(
-                            Inference, zip(batch_values["str_y"], batch_values["str_x"])
+                            Inference,
+                            zip(
+                                batch_data["names"],
+                                batch_values["str_y"],
+                                batch_values["str_x"],
+                                repeat(display_values["wer"]),
+                            ),
                         )
                     )
 
diff --git a/tests/data/evaluate/metrics_table.md b/tests/data/evaluate/metrics_table.md
index 76d976c3..b48adcf4 100644
--- a/tests/data/evaluate/metrics_table.md
+++ b/tests/data/evaluate/metrics_table.md
@@ -47,3 +47,20 @@
 | Chalumeau |     1     |    0    |    0.0    |  0.0   |   0   |    1    |
 |  Batiment |     1     |    1    |    1.0    |  1.0   |  1.0  |    1    |
 |    All    |     6     |    5    |   0.833   | 0.833  | 0.833 |    6    |
+
+#### 5 worst prediction(s)
+
+|                Image name                |  WER   |        Alignment between ground truth - prediction        |
+|:----------------------------------------:|:------:|:---------------------------------------------------------:|
+| 2c242f5c-e979-43c4-b6f2-a6d4815b651d.png |  0.5   |             ⓈA ⒻCharles Ⓑ11 ⓁP ⒸC ⓀF ⓄA Ⓟ14331            |
+|                                          |        |             |.||||||||||||||||||||||||.||||.||            |
+|                                          |        |             Ⓢd ⒻCharles Ⓑ11 ⓁP ⒸC ⓀF Ⓞd Ⓟ14 31            |
+| 0dfe8bcd-ed0b-453e-bf19-cc697012296e.png | 0.2667 |      ⓈTemplié ⒻMarcelle Ⓑ93 ⓁJ Ⓚch ⓄE dachyle-------      |
+|                                          |        |      ||||||||||||||||||||||||.|||||||||||.||.-------      |
+|                                          |        |      ⓈTemplié ⒻMarcelle Ⓑ93 ⓁS Ⓚch ⓄE dactylo Ⓟ18376      |
+| ffdec445-7f14-4f5f-be44-68d0844d0df1.png | 0.1429 |            ⓈNaudin ⒻMarie Ⓑ53 ⓁS ⒸV ⓀBelle mère           |
+|                                          |        |            |||||||||||||||||||||||.||||||||||||           |
+|                                          |        |            ⓈNaudin ⒻMarie Ⓑ53 ⓁS Ⓒv ⓀBelle mère           |
+| 0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84.png | 0.125  | ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier Ⓟ-------12241 |
+|                                          |        | |||||||||||||||||||||||||||||||||||||||||||||-------||||| |
+|                                          |        | ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241 |
-- 
GitLab