Skip to content
Snippets Groups Projects
Commit 0930e59d authored by Manon Blanco
Browse files

Display 5 worst predictions at the end of evaluation

parent 5af57960
No related branches found
No related tags found
No related merge requests found
This commit is part of merge request !341. Comments created here will be created in the context of that merge request.
......@@ -6,12 +6,16 @@ Evaluate a trained DAN model.
import logging
import random
from argparse import ArgumentTypeError
from itertools import chain
from operator import attrgetter
from pathlib import Path
from typing import Dict, List
import numpy as np
import torch
import torch.multiprocessing as mp
from edlib import align, getNiceAlignment
from prettytable import MARKDOWN, PrettyTable
from dan.bio import convert
from dan.ocr.manager.metrics import Inference
......@@ -25,6 +29,7 @@ from nerval.utils import print_results
logger = logging.getLogger(__name__)
NERVAL_THRESHOLD = 0.30
NB_WORST_PREDICTIONS = 5
def parse_threshold(arg):
......@@ -63,6 +68,38 @@ def add_evaluate_parser(subcommands) -> None:
parser.set_defaults(func=run)
def print_worst_predictions(all_inferences: Dict[str, List[Inference]]):
    """
    Print a Markdown table of the inferences with the highest WER.

    Every list of inferences in `all_inferences` is flattened into a single
    sequence, ranked by decreasing `wer`, and at most `NB_WORST_PREDICTIONS`
    entries are kept. Each kept entry produces one row holding the image name,
    its WER and a three-line character alignment (ground truth, match markers,
    prediction) computed with edlib.

    :param all_inferences: Inferences grouped under string keys.
    """
    # Flatten every group, then rank from highest to lowest WER.
    # `list.sort` is stable, so ties keep their original relative order.
    ranked = list(chain.from_iterable(all_inferences.values()))
    ranked.sort(key=attrgetter("wer"), reverse=True)
    worst = ranked[:NB_WORST_PREDICTIONS]

    report = PrettyTable(
        field_names=[
            "Image name",
            "WER",
            "Alignment between ground truth - prediction",
        ]
    )
    report.set_style(MARKDOWN)

    for item in worst:
        truth = item.ground_truth
        predicted = item.prediction

        # Ground truth is the alignment query, prediction is the target.
        # task="path" is required so that the alignment path is available
        # to getNiceAlignment.
        path = align(truth, predicted, task="path")
        pretty = getNiceAlignment(path, truth, predicted)

        query_row = pretty["query_aligned"]
        match_row = pretty["matched_aligned"]
        target_row = pretty["target_aligned"]
        report.add_row(
            [item.image, item.wer, f"{query_row}\n{match_row}\n{target_row}"]
        )

    # The heading always shows the configured limit, even when fewer
    # inferences are available.
    print(f"\n#### {NB_WORST_PREDICTIONS} worst prediction(s)\n")
    print(report)
def eval_nerval(
all_inferences: Dict[str, List[Inference]],
tokens: Path,
......@@ -149,6 +186,8 @@ def eval(rank, config: dict, nerval_threshold: float, mlflow_logging: bool):
threshold=nerval_threshold,
)
print_worst_predictions(all_inferences)
def run(config: dict, nerval_threshold: float):
update_config(config)
......
......@@ -26,8 +26,10 @@ METRICS_KEYWORD = {"cer": "chars", "wer": "words", "ner": "tokens"}
@dataclass
class Inference:
    """
    Store one evaluated sample: image name, reference text, predicted text and WER.

    Field order matters: instances are built positionally
    (`starmap(Inference, zip(...))`), so fields must stay in this order.
    """

    # Name of the image this inference refers to.
    image: str
    # Reference transcription of the image.
    ground_truth: str
    # Transcription predicted by the model.
    prediction: str
    # Word error rate attached to this inference.
    # NOTE(review): the visible call site passes `repeat(display_values["wer"])`,
    # i.e. the same value for every item zipped from a batch - confirm whether
    # this is a per-image or a per-batch WER.
    wer: float
class MetricManager:
......
......@@ -4,7 +4,7 @@ import os
import random
from copy import deepcopy
from enum import Enum
from itertools import starmap
from itertools import repeat, starmap
from pathlib import Path
from time import time
from typing import Dict, List, Tuple
......@@ -769,7 +769,9 @@ class GenericTrainingManager:
tokens=self.tokens,
)
# Keep inferences in memory to evaluate with Nerval
# Keep inferences in memory to:
# - evaluate with Nerval
# - display worst predictions
inferences = []
with tqdm(total=len(loader.dataset)) as pbar:
......@@ -798,7 +800,13 @@ class GenericTrainingManager:
inferences.extend(
starmap(
Inference, zip(batch_values["str_y"], batch_values["str_x"])
Inference,
zip(
batch_data["names"],
batch_values["str_y"],
batch_values["str_x"],
repeat(display_values["wer"]),
),
)
)
......
......@@ -47,3 +47,20 @@
| Chalumeau | 1 | 0 | 0.0 | 0.0 | 0 | 1 |
| Batiment | 1 | 1 | 1.0 | 1.0 | 1.0 | 1 |
| All | 6 | 5 | 0.833 | 0.833 | 0.833 | 6 |
#### 5 worst prediction(s)
| Image name | WER | Alignment between ground truth - prediction |
|:----------------------------------------:|:------:|:---------------------------------------------------------:|
| 2c242f5c-e979-43c4-b6f2-a6d4815b651d.png | 0.5 | ⓈA ⒻCharles Ⓑ11 ⓁP ⒸC ⓀF ⓄA Ⓟ14331 |
| | | |.||||||||||||||||||||||||.||||.|| |
| | | Ⓢd ⒻCharles Ⓑ11 ⓁP ⒸC ⓀF Ⓞd Ⓟ14 31 |
| 0dfe8bcd-ed0b-453e-bf19-cc697012296e.png | 0.2667 | ⓈTemplié ⒻMarcelle Ⓑ93 ⓁJ Ⓚch ⓄE dachyle------- |
| | | ||||||||||||||||||||||||.|||||||||||.||.------- |
| | | ⓈTemplié ⒻMarcelle Ⓑ93 ⓁS Ⓚch ⓄE dactylo Ⓟ18376 |
| ffdec445-7f14-4f5f-be44-68d0844d0df1.png | 0.1429 | ⓈNaudin ⒻMarie Ⓑ53 ⓁS ⒸV ⓀBelle mère |
| | | |||||||||||||||||||||||.|||||||||||| |
| | | ⓈNaudin ⒻMarie Ⓑ53 ⓁS Ⓒv ⓀBelle mère |
| 0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84.png | 0.125 | ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier Ⓟ-------12241 |
| | | |||||||||||||||||||||||||||||||||||||||||||||-------||||| |
| | | ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241 |
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment