Commit 5af57960 authored by Manon Blanco

Improve code and display

parent c1ab7ab8
@@ -6,12 +6,15 @@ Evaluate a trained DAN model.
import logging
import random
from argparse import ArgumentTypeError
from pathlib import Path
from typing import Dict, List
import numpy as np
import torch
import torch.multiprocessing as mp
from dan.bio import convert
from dan.ocr.manager.metrics import Inference
from dan.ocr.manager.training import Manager
from dan.ocr.utils import add_metrics_table_row, create_metrics_table, update_config
from dan.utils import parse_tokens, read_json
@@ -60,6 +63,37 @@ def add_evaluate_parser(subcommands) -> None:
    parser.set_defaults(func=run)


def eval_nerval(
    all_inferences: Dict[str, List[Inference]],
    tokens: Path,
    threshold: float,
):
    print("\n#### Nerval evaluation")

    def inferences_to_parsed_bio(attr: str):
        bio_values = []
        for inference in inferences:
            value = getattr(inference, attr)
            bio_value = convert(value, ner_tokens=tokens)
            bio_values.extend(bio_value.split("\n"))

        # Parse this BIO format
        return parse_bio(bio_values)

    # Evaluate with Nerval
    tokens = parse_tokens(tokens)
    for split_name, inferences in all_inferences.items():
        ground_truths = inferences_to_parsed_bio("ground_truth")
        predictions = inferences_to_parsed_bio("prediction")

        if not (ground_truths and predictions):
            continue

        scores = evaluate(ground_truths, predictions, threshold)
        print(f"\n##### {split_name}\n")
        print_results(scores)


def eval(rank, config: dict, nerval_threshold: float, mlflow_logging: bool):
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)
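
As an aside (not part of this commit), a minimal usage sketch of the new `eval_nerval` helper, assuming it is called from the same module: `Inference` objects are expected to expose `ground_truth` and `prediction` attributes (as read by `inferences_to_parsed_bio`), and `tokens` points to the NER tokens file that `eval` passes as `config["dataset"]["tokens"]`. The split name, file path, and threshold below are illustrative.

```python
from pathlib import Path
from typing import List

from dan.ocr.manager.metrics import Inference

# Placeholder: in `eval` these come from the training Manager's evaluation,
# one list of Inference objects per evaluated split.
inferences: List[Inference] = []

eval_nerval(
    {"test": inferences},          # one entry per evaluated split
    tokens=Path("tokens.yml"),     # hypothetical NER tokens file path
    threshold=0.30,                # illustrative Nerval match threshold
)
```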
@@ -105,33 +139,15 @@ def eval(rank, config: dict, nerval_threshold: float, mlflow_logging: bool):
        add_metrics_table_row(metrics_table, set_name, metrics)
        all_inferences[set_name] = inferences

    print("\n#### DAN evaluation\n")
    print(metrics_table)

    if "ner" not in metric_names:
        return

    def inferences_to_parsed_bio(attr: str):
        bio_values = []
        for inference in inferences:
            value = getattr(inference, attr)
            bio_value = convert(value, ner_tokens=tokens)
            bio_values.extend(bio_value.split("\n"))

        # Parse this BIO format
        return parse_bio(bio_values)

    # Evaluate with Nerval
    tokens = parse_tokens(config["dataset"]["tokens"])
    for set_name, inferences in all_inferences.items():
        ground_truths = inferences_to_parsed_bio("ground_truth")
        predictions = inferences_to_parsed_bio("prediction")

        if not (ground_truths and predictions):
            continue

        scores = evaluate(ground_truths, predictions, nerval_threshold)
        print(f"\n#### {set_name}\n")
        print_results(scores)

    if "ner" in metric_names:
        eval_nerval(
            all_inferences,
            tokens=config["dataset"]["tokens"],
            threshold=nerval_threshold,
        )

def run(config: dict, nerval_threshold: float):
@@ -24,37 +24,47 @@ This will, for each evaluated split:
### HTR evaluation
```
#### DAN evaluation
| Split | CER (HTR-NER) | CER (HTR) | WER (HTR-NER) | WER (HTR) | WER (HTR no punct) |
| :---: | :-----------: | :-------: | :-----------: | :-------: | :----------------: |
| train | x | x | x | x | x |
| val | x | x | x | x | x |
| test | x | x | x | x | x |
```
### HTR and NER evaluation
```
#### DAN evaluation
| Split | CER (HTR-NER) | CER (HTR) | WER (HTR-NER) | WER (HTR) | WER (HTR no punct) | NER |
| :---: | :-----------: | :-------: | :-----------: | :-------: | :----------------: | :-: |
| train | x | x | x | x | x | x |
| val | x | x | x | x | x | x |
| test | x | x | x | x | x | x |
#### train
#### Nerval evaluation
##### train
| tag | predicted | matched | Precision | Recall | F1 | Support |
| :-----: | :-------: | :-----: | :-------: | :----: | :-: | :-----: |
| Surname | x | x | x | x | x | x |
| All | x | x | x | x | x | x |
#### val
##### val
| tag | predicted | matched | Precision | Recall | F1 | Support |
| :-----: | :-------: | :-----: | :-------: | :----: | :-: | :-----: |
| Surname | x | x | x | x | x | x |
| All | x | x | x | x | x | x |
#### test
##### test
| tag | predicted | matched | Precision | Recall | F1 | Support |
| :-----: | :-------: | :-----: | :-------: | :----: | :-: | :-----: |
| Surname | x | x | x | x | x | x |
| All | x | x | x | x | x | x |
```
#### DAN evaluation
| Split | CER (HTR-NER) | CER (HTR) | WER (HTR-NER) | WER (HTR) | WER (HTR no punct) | NER |
|:-----:|:-------------:|:---------:|:-------------:|:---------:|:------------------:|:----:|
| train | 18.89 | 21.05 | 26.67 | 26.67 | 26.67 | 7.14 |
| val | 8.82 | 11.54 | 50.0 | 50.0 | 50.0 | 0.0 |
| test | 2.78 | 3.33 | 14.29 | 14.29 | 14.29 | 0.0 |
#### train
#### Nerval evaluation
##### train
| tag | predicted | matched | Precision | Recall | F1 | Support |
|:---------:|:---------:|:-------:|:---------:|:------:|:-----:|:-------:|
@@ -18,7 +22,7 @@
| Batiment | 2 | 2 | 1.0 | 1.0 | 1.0 | 2 |
| All | 15 | 12 | 0.8 | 0.857 | 0.828 | 14 |
#### val
##### val
| tag | predicted | matched | Precision | Recall | F1 | Support |
|:---------:|:---------:|:-------:|:---------:|:------:|:----:|:-------:|
@@ -32,7 +36,7 @@
| Batiment | 1 | 1 | 1.0 | 1.0 | 1.0 | 1 |
| All | 8 | 6 | 0.75 | 0.75 | 0.75 | 8 |
#### test
##### test
| tag | predicted | matched | Precision | Recall | F1 | Support |
|:---------:|:---------:|:-------:|:---------:|:------:|:-----:|:-------:|
@@ -129,7 +129,7 @@ def test_evaluate(capsys, training_res, val_res, test_res, evaluate_config):
    # Check the metrics Markdown table
    captured_std = capsys.readouterr()
    last_printed_lines = captured_std.out.split("\n")[-46:]
    last_printed_lines = captured_std.out.split("\n")[10:]
    assert (
        "\n".join(last_printed_lines)
        == Path(FIXTURES / "evaluate" / "metrics_table.md").read_text()
    )