diff --git a/dan/ocr/evaluate.py b/dan/ocr/evaluate.py
index ce48c2f01570de7172d22774cb4d4b08f69fbeb3..041d309e6e9be7919360f8bc29dc721ce404068e 100644
--- a/dan/ocr/evaluate.py
+++ b/dan/ocr/evaluate.py
@@ -6,12 +6,15 @@ Evaluate a trained DAN model.
 import logging
 import random
 from argparse import ArgumentTypeError
+from pathlib import Path
+from typing import Dict, List
 
 import numpy as np
 import torch
 import torch.multiprocessing as mp
 
 from dan.bio import convert
+from dan.ocr.manager.metrics import Inference
 from dan.ocr.manager.training import Manager
 from dan.ocr.utils import add_metrics_table_row, create_metrics_table, update_config
 from dan.utils import parse_tokens, read_json
@@ -60,6 +63,37 @@ def add_evaluate_parser(subcommands) -> None:
     parser.set_defaults(func=run)
 
 
+def eval_nerval(
+    all_inferences: Dict[str, List[Inference]],
+    tokens: Path,
+    threshold: float,
+):
+    print("\n#### Nerval evaluation")
+
+    def inferences_to_parsed_bio(attr: str):
+        bio_values = []
+        for inference in inferences:
+            value = getattr(inference, attr)
+            bio_value = convert(value, ner_tokens=tokens)
+            bio_values.extend(bio_value.split("\n"))
+
+        # Parse this BIO format
+        return parse_bio(bio_values)
+
+    # Evaluate with Nerval
+    tokens = parse_tokens(tokens)
+    for split_name, inferences in all_inferences.items():
+        ground_truths = inferences_to_parsed_bio("ground_truth")
+        predictions = inferences_to_parsed_bio("prediction")
+
+        if not (ground_truths and predictions):
+            continue
+
+        scores = evaluate(ground_truths, predictions, threshold)
+        print(f"\n##### {split_name}\n")
+        print_results(scores)
+
+
 def eval(rank, config: dict, nerval_threshold: float, mlflow_logging: bool):
     torch.manual_seed(0)
     torch.cuda.manual_seed(0)
@@ -105,33 +139,15 @@ def eval(rank, config: dict, nerval_threshold: float, mlflow_logging: bool):
         add_metrics_table_row(metrics_table, set_name, metrics)
         all_inferences[set_name] = inferences
 
+    print("\n#### DAN evaluation\n")
     print(metrics_table)
 
-    if "ner" not in metric_names:
-        return
-
-    def inferences_to_parsed_bio(attr: str):
-        bio_values = []
-        for inference in inferences:
-            value = getattr(inference, attr)
-            bio_value = convert(value, ner_tokens=tokens)
-            bio_values.extend(bio_value.split("\n"))
-
-        # Parse this BIO format
-        return parse_bio(bio_values)
-
-    # Evaluate with Nerval
-    tokens = parse_tokens(config["dataset"]["tokens"])
-    for set_name, inferences in all_inferences.items():
-        ground_truths = inferences_to_parsed_bio("ground_truth")
-        predictions = inferences_to_parsed_bio("prediction")
-
-        if not (ground_truths and predictions):
-            continue
-
-        scores = evaluate(ground_truths, predictions, nerval_threshold)
-        print(f"\n#### {set_name}\n")
-        print_results(scores)
+    if "ner" in metric_names:
+        eval_nerval(
+            all_inferences,
+            tokens=config["dataset"]["tokens"],
+            threshold=nerval_threshold,
+        )
 
 
 def run(config: dict, nerval_threshold: float):
diff --git a/docs/usage/evaluate/index.md b/docs/usage/evaluate/index.md
index 3764e245073b7301499c24ce11e7477afac1f8a4..84039e79bd38f8d6f3a20c64eb81dd8a7e20e635 100644
--- a/docs/usage/evaluate/index.md
+++ b/docs/usage/evaluate/index.md
@@ -24,37 +24,47 @@ This will, for each evaluated split:
 
 ### HTR evaluation
 
+```
+#### DAN evaluation
+
 | Split | CER (HTR-NER) | CER (HTR) | WER (HTR-NER) | WER (HTR) | WER (HTR no punct) |
 | :---: | :-----------: | :-------: | :-----------: | :-------: | :----------------: |
 | train | x | x | x | x | x |
 | val | x | x | x | x | x |
 | test | x | x | x | x | x |
+```
 
 ### HTR and NER evaluation
 
+```
+#### DAN evaluation
+
 | Split | CER (HTR-NER) | CER (HTR) | WER (HTR-NER) | WER (HTR) | WER (HTR no punct) | NER |
 | :---: | :-----------: | :-------: | :-----------: | :-------: | :----------------: | :-: |
 | train | x | x | x | x | x | x |
 | val | x | x | x | x | x | x |
 | test | x | x | x | x | x | x |
 
-#### train
+#### Nerval evaluation
+
+##### train
 
 | tag | predicted | matched | Precision | Recall | F1 | Support |
 | :-----: | :-------: | :-----: | :-------: | :----: | :-: | :-----: |
 | Surname | x | x | x | x | x | x |
 | All | x | x | x | x | x | x |
 
-#### val
+##### val
 
 | tag | predicted | matched | Precision | Recall | F1 | Support |
 | :-----: | :-------: | :-----: | :-------: | :----: | :-: | :-----: |
 | Surname | x | x | x | x | x | x |
 | All | x | x | x | x | x | x |
 
-#### test
+##### test
 
 | tag | predicted | matched | Precision | Recall | F1 | Support |
 | :-----: | :-------: | :-----: | :-------: | :----: | :-: | :-----: |
 | Surname | x | x | x | x | x | x |
 | All | x | x | x | x | x | x |
+```
diff --git a/tests/data/evaluate/metrics_table.md b/tests/data/evaluate/metrics_table.md
index b33a3bf04d2018eef678bbc37841c8753a314366..76d976c389571a16bb42e4f354d5b208f09faa64 100644
--- a/tests/data/evaluate/metrics_table.md
+++ b/tests/data/evaluate/metrics_table.md
@@ -1,10 +1,14 @@
+#### DAN evaluation
+
 | Split | CER (HTR-NER) | CER (HTR) | WER (HTR-NER) | WER (HTR) | WER (HTR no punct) | NER |
 |:-----:|:-------------:|:---------:|:-------------:|:---------:|:------------------:|:----:|
 | train | 18.89 | 21.05 | 26.67 | 26.67 | 26.67 | 7.14 |
 | val | 8.82 | 11.54 | 50.0 | 50.0 | 50.0 | 0.0 |
 | test | 2.78 | 3.33 | 14.29 | 14.29 | 14.29 | 0.0 |
 
-#### train
+#### Nerval evaluation
+
+##### train
 
 | tag | predicted | matched | Precision | Recall | F1 | Support |
 |:---------:|:---------:|:-------:|:---------:|:------:|:-----:|:-------:|
@@ -18,7 +22,7 @@
 | Batiment | 2 | 2 | 1.0 | 1.0 | 1.0 | 2 |
 | All | 15 | 12 | 0.8 | 0.857 | 0.828 | 14 |
 
-#### val
+##### val
 
 | tag | predicted | matched | Precision | Recall | F1 | Support |
 |:---------:|:---------:|:-------:|:---------:|:------:|:----:|:-------:|
@@ -32,7 +36,7 @@
 | Batiment | 1 | 1 | 1.0 | 1.0 | 1.0 | 1 |
 | All | 8 | 6 | 0.75 | 0.75 | 0.75 | 8 |
 
-#### test
+##### test
 
 | tag | predicted | matched | Precision | Recall | F1 | Support |
 |:---------:|:---------:|:-------:|:---------:|:------:|:-----:|:-------:|
diff --git a/tests/test_evaluate.py b/tests/test_evaluate.py
index 871c415d2d2aa77634de06b050395d006b2037f9..0bdf51968cf985a7c1fa1dcac00aedbe4b474091 100644
--- a/tests/test_evaluate.py
+++ b/tests/test_evaluate.py
@@ -129,7 +129,7 @@ def test_evaluate(capsys, training_res, val_res, test_res, evaluate_config):
 
     # Check the metrics Markdown table
     captured_std = capsys.readouterr()
-    last_printed_lines = captured_std.out.split("\n")[-46:]
+    last_printed_lines = captured_std.out.split("\n")[10:]
     assert (
         "\n".join(last_printed_lines)
        == Path(FIXTURES / "evaluate" / "metrics_table.md").read_text()
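For reference, a minimal sketch of how the extracted `eval_nerval` helper could be driven on its own, mirroring the call made from `eval()`. The `tokens.yml` path and the `0.30` threshold are illustrative placeholders (the tokens file must exist and define the NER entities), and the inference mapping would normally be built per split inside `eval()`:

```python
# Hypothetical driver for the refactored eval_nerval() helper; the path and
# the threshold value below are placeholders, not values taken from this diff.
from pathlib import Path
from typing import Dict, List

from dan.ocr.evaluate import eval_nerval
from dan.ocr.manager.metrics import Inference

# In eval(), this mapping is filled with the per-split inference results;
# it is left empty here so the sketch stays self-contained.
all_inferences: Dict[str, List[Inference]] = {}

eval_nerval(
    all_inferences,
    tokens=Path("tokens.yml"),  # NER tokens definition, i.e. config["dataset"]["tokens"]
    threshold=0.30,             # Nerval match threshold (the nerval_threshold argument)
)
```

The helper prints the same `#### Nerval evaluation` / `##### <split>` headings that the updated documentation and the `metrics_table.md` test fixture above expect.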