Compare revisions (atr/dan)
Commits on Source (2)
[submodule "nerval"]
path = nerval
url = ../../ner/nerval.git
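Note (not part of the diff): the submodule URL is relative, so git resolves it against the parent repository's own remote namespace. After checking out this change the submodule still has to be fetched, e.g. with the standard git submodule update --init, before the editable -e ./nerval requirement added further down can be installed.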
@@ -5,17 +5,36 @@ Evaluate a trained DAN model.
import logging
import random
from argparse import ArgumentTypeError
import numpy as np
import torch
import torch.multiprocessing as mp
from dan.bio import convert
from dan.ocr.manager.training import Manager
from dan.ocr.utils import add_metrics_table_row, create_metrics_table, update_config
- from dan.utils import read_json
+ from dan.utils import parse_tokens, read_json
from nerval.evaluate import evaluate
from nerval.parse import parse_bio
from nerval.utils import print_results
logger = logging.getLogger(__name__)
NERVAL_THRESHOLD = 0.30
def parse_threshold(arg):
try:
f = float(arg)
except ValueError:
raise ArgumentTypeError("Must be a floating point number.")
if f < 0 or f > 1:
raise ArgumentTypeError("Must be between 0 and 1.")
return f
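Note (standalone sketch, not part of the diff): wired into argparse, parse_threshold only accepts floats in [0, 1]; the throwaway parser below is hypothetical and just illustrates that behaviour.

from argparse import ArgumentParser

parser = ArgumentParser(prog="evaluate")
parser.add_argument("--nerval-threshold", type=parse_threshold, default=NERVAL_THRESHOLD)

args = parser.parse_args(["--nerval-threshold", "0.25"])
assert args.nerval_threshold == 0.25
# "--nerval-threshold 1.5" or "--nerval-threshold abc" make argparse exit with
# "Must be between 0 and 1." / "Must be a floating point number." respectively.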
def add_evaluate_parser(subcommands) -> None:
parser = subcommands.add_parser(
@@ -31,10 +50,17 @@ def add_evaluate_parser(subcommands) -> None:
help="Configuration file.",
)
parser.add_argument(
"--nerval-threshold",
help="Distance threshold for the match between gold and predicted entity during Nerval evaluation.",
default=NERVAL_THRESHOLD,
type=parse_threshold,
)
parser.set_defaults(func=run)
- def eval(rank, config, mlflow_logging):
+ def eval(rank, config: dict, nerval_threshold: float, mlflow_logging: bool):
torch.manual_seed(0)
torch.cuda.manual_seed(0)
np.random.seed(0)
@@ -62,10 +88,12 @@ def eval(rank, config, mlflow_logging):
metric_names.append("ner")
metrics_table = create_metrics_table(metric_names)
all_inferences = {}
for dataset_name in config["dataset"]["datasets"]:
for set_name in ["train", "val", "test"]:
logger.info(f"Evaluating on set `{set_name}`")
- metrics = model.evaluate(
+ metrics, inferences = model.evaluate(
"{}-{}".format(dataset_name, set_name),
[
(dataset_name, set_name),
@@ -75,11 +103,41 @@
)
add_metrics_table_row(metrics_table, set_name, metrics)
all_inferences[set_name] = inferences
print(metrics_table)
if "ner" not in metric_names:
return
print()
def inferences_to_parsed_bio(attr: str):
bio_values = []
for inference in inferences:
values = getattr(inference, attr)
for value in values:
bio_value = convert(value, ner_tokens=tokens)
bio_values.extend(bio_value.split("\n"))
# Parse this BIO format
return parse_bio(bio_values)
# Evaluate with Nerval
tokens = parse_tokens(config["dataset"]["tokens"])
for set_name, inferences in all_inferences.items():
ground_truths = inferences_to_parsed_bio("ground_truth")
predictions = inferences_to_parsed_bio("prediction")
if not (ground_truths and predictions):
continue
scores = evaluate(ground_truths, predictions, nerval_threshold)
print(set_name)
print_results(scores)
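Note (illustration, not part of the diff): Nerval scores BIO-tagged tokens, so convert() first turns each transcription into one "word TAG" line per word, and parse_bio() groups those lines back into entities before evaluate() aligns them. A hypothetical ground-truth snippet could look like this (tag names taken from the fixture tables further down; the words and the NER tokens in config["dataset"]["tokens"] are made up / not shown):

Jean B-Firstname
Dupont B-Surname
menuisier O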
- def run(config: dict):
+ def run(config: dict, nerval_threshold: float):
update_config(config)
mlflow_logging = bool(config.get("mlflow"))
@@ -94,8 +152,8 @@ def run(config: dict):
):
mp.spawn(
eval,
- args=(config, mlflow_logging),
+ args=(config, nerval_threshold, mlflow_logging),
nprocs=config["training"]["device"]["nb_gpu"],
)
else:
- eval(0, config, mlflow_logging)
+ eval(0, config, nerval_threshold, mlflow_logging)
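Note: with these changes the evaluate subcommand takes the new flag on top of its existing arguments, e.g. (assuming the project's usual teklia-dan entry point; the threshold value is arbitrary and the rest of the command line is elided):

teklia-dan evaluate ... --nerval-threshold 0.25

Leaving the flag out falls back to NERVAL_THRESHOLD (0.30).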
# -*- coding: utf-8 -*-
import re
from collections import defaultdict
from dataclasses import dataclass
from operator import attrgetter
from pathlib import Path
from typing import Dict, List
@@ -23,6 +24,12 @@ REGEX_ONLY_ONE_SPACE = re.compile(r"\s+")
METRICS_KEYWORD = {"cer": "chars", "wer": "words", "ner": "tokens"}
@dataclass
class Inference:
ground_truth: List[str]
prediction: List[str]
class MetricManager:
def __init__(self, metric_names: List[str], dataset_name: str, tokens: Path | None):
self.dataset_name: str = dataset_name
@@ -40,6 +47,9 @@ class MetricManager:
self.metric_names: List[str] = metric_names
self.epoch_metrics = defaultdict(list)
# List of inferences (predictions paired with their ground truths)
self.inferences = []
def format_string_for_cer(self, text: str, remove_token: bool = False):
"""
Format string for CER computation: remove layout tokens and extra spaces
@@ -155,6 +165,8 @@ class MetricManager:
metrics["time"] = [values["time"]]
gt, prediction = values["str_y"], values["str_x"]
self.inferences.append(Inference(ground_truth=gt, prediction=prediction))
for metric_name in metric_names:
match metric_name:
case (
......
@@ -6,7 +6,7 @@ from copy import deepcopy
from enum import Enum
from pathlib import Path
from time import time
- from typing import Dict
+ from typing import Dict, List, Tuple
import numpy as np
import torch
@@ -20,7 +20,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
- from dan.ocr.manager.metrics import MetricManager
+ from dan.ocr.manager.metrics import Inference, MetricManager
from dan.ocr.manager.ocr import OCRDatasetManager
from dan.ocr.mlflow import MLFLOW_AVAILABLE, logging_metrics, logging_tags_metrics
from dan.ocr.schedulers import DropoutScheduler
@@ -750,7 +750,7 @@ class GenericTrainingManager:
def evaluate(
self, custom_name, sets_list, metric_names, mlflow_logging=False
- ) -> Dict[str, int | float]:
+ ) -> Tuple[Dict[str, int | float], List[Inference]]:
"""
Main loop for evaluation
"""
@@ -810,7 +810,7 @@
# Log mlflow artifacts
mlflow.log_artifact(path, "predictions")
- return metrics
+ return metrics, self.metric_manager[custom_name].inferences
def output_pred(self, name):
path = self.paths["results"] / "predict_{}_{}.yaml".format(
......
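Note (minimal sketch, not part of the diff): callers of evaluate() now unpack a pair instead of a bare metrics dict, mirroring the call site updated in evaluate.py above; the manager and dataset names below are placeholders.

metrics, inferences = model.evaluate(
    "my-dataset-test",
    [("my-dataset", "test")],
    metric_names,
    mlflow_logging=False,
)
for inference in inferences:
    print(inference.ground_truth, inference.prediction)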
-e ./nerval
albumentations==1.3.1
arkindex-export==0.1.9
boto3==1.26.124
editdistance==0.6.2
flashlight-text==0.0.4
imageio==2.26.1
imagesize==1.4.1
@@ -9,7 +9,6 @@ lxml==4.9.3
mdutils==1.6.0
nltk==3.8.1
numpy==1.24.3
prettytable==3.8.0
PyYAML==6.0
scipy==1.10.1
sentencepiece==0.1.99
......
@@ -3,3 +3,38 @@
| train | 18.89 | 21.05 | 26.67 | 26.67 | 26.67 | 7.14 |
| val | 8.82 | 11.54 | 50.0 | 50.0 | 50.0 | 0.0 |
| test | 2.78 | 3.33 | 14.29 | 14.29 | 14.29 | 0.0 |
train
| tag | predicted | matched | Precision | Recall | F1 | Support |
|:---------:|:---------:|:-------:|:---------:|:------:|:-----:|:-------:|
| Surname | 2 | 2 | 1.0 | 1.0 | 1.0 | 2 |
| Patron | 2 | 0 | 0.0 | 0.0 | 0 | 1 |
| Operai | 2 | 2 | 1.0 | 1.0 | 1.0 | 2 |
| Louche | 2 | 1 | 0.5 | 0.5 | 0.5 | 2 |
| Koala | 2 | 2 | 1.0 | 1.0 | 1.0 | 2 |
| Firstname | 2 | 2 | 1.0 | 1.0 | 1.0 | 2 |
| Chalumeau | 1 | 1 | 1.0 | 1.0 | 1.0 | 1 |
| Batiment | 2 | 2 | 1.0 | 1.0 | 1.0 | 2 |
| All | 15 | 12 | 0.8 | 0.857 | 0.828 | 14 |
val
| tag | predicted | matched | Precision | Recall | F1 | Support |
|:---------:|:---------:|:-------:|:---------:|:------:|:----:|:-------:|
| Surname | 1 | 0 | 0.0 | 0.0 | 0 | 1 |
| Patron | 1 | 1 | 1.0 | 1.0 | 1.0 | 1 |
| Operai | 1 | 0 | 0.0 | 0.0 | 0 | 1 |
| Louche | 1 | 1 | 1.0 | 1.0 | 1.0 | 1 |
| Koala | 1 | 1 | 1.0 | 1.0 | 1.0 | 1 |
| Firstname | 1 | 1 | 1.0 | 1.0 | 1.0 | 1 |
| Chalumeau | 1 | 1 | 1.0 | 1.0 | 1.0 | 1 |
| Batiment | 1 | 1 | 1.0 | 1.0 | 1.0 | 1 |
| All | 8 | 6 | 0.75 | 0.75 | 0.75 | 8 |
test
| tag | predicted | matched | Precision | Recall | F1 | Support |
|:---------:|:---------:|:-------:|:---------:|:------:|:-----:|:-------:|
| Surname | 1 | 1 | 1.0 | 1.0 | 1.0 | 1 |
| Louche | 1 | 1 | 1.0 | 1.0 | 1.0 | 1 |
| Koala | 1 | 1 | 1.0 | 1.0 | 1.0 | 1 |
| Firstname | 1 | 1 | 1.0 | 1.0 | 1.0 | 1 |
| Chalumeau | 1 | 0 | 0.0 | 0.0 | 0 | 1 |
| Batiment | 1 | 1 | 1.0 | 1.0 | 1.0 | 1 |
| All | 6 | 5 | 0.833 | 0.833 | 0.833 | 6 |
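Reading the new Nerval tables: predicted counts the entities of a tag in the predictions, Support counts them in the ground truth, and matched counts those aligned within the threshold; Precision and Recall then follow as matched / predicted and matched / Support, with F1 their harmonic mean, which is what the numbers above bear out. For example, the train "All" row gives 12 / 15 = 0.8, 12 / 14 ≈ 0.857 and 2 * 0.8 * 0.857 / (0.8 + 0.857) ≈ 0.828.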
@@ -103,7 +103,7 @@ def test_evaluate(capsys, training_res, val_res, test_res, evaluate_config):
# Use the tmp_path as base folder
evaluate_config["training"]["output_folder"] = FIXTURES / "evaluate"
- evaluate.run(evaluate_config)
+ evaluate.run(evaluate_config, evaluate.NERVAL_THRESHOLD)
# Check that the evaluation results are correct
for split_name, expected_res in zip(
@@ -129,7 +129,7 @@ def test_evaluate(capsys, training_res, val_res, test_res, evaluate_config):
# Check the metrics Markdown table
captured_std = capsys.readouterr()
- last_printed_lines = captured_std.out.split("\n")[-6:]
+ last_printed_lines = captured_std.out.split("\n")[-41:]
assert (
"\n".join(last_printed_lines)
== Path(FIXTURES / "evaluate" / "metrics_table.md").read_text()
......