Compare revisions

Target project: atr/dan

Commits on Source (2)
+[submodule "nerval"]
+	path = nerval
+	url = ../../ner/nerval.git
@@ -5,17 +5,36 @@ Evaluate a trained DAN model.
import logging
import random
+from argparse import ArgumentTypeError

import numpy as np
import torch
import torch.multiprocessing as mp

+from dan.bio import convert
from dan.ocr.manager.training import Manager
from dan.ocr.utils import add_metrics_table_row, create_metrics_table, update_config
-from dan.utils import read_json
+from dan.utils import parse_tokens, read_json
+from nerval.evaluate import evaluate
+from nerval.parse import parse_bio
+from nerval.utils import print_results

logger = logging.getLogger(__name__)

+NERVAL_THRESHOLD = 0.30
+
+
+def parse_threshold(arg):
+    try:
+        f = float(arg)
+    except ValueError:
+        raise ArgumentTypeError("Must be a floating point number.")
+    if f < 0 or f > 1:
+        raise ArgumentTypeError("Must be between 0 and 1.")
+    return f
+
+
def add_evaluate_parser(subcommands) -> None:
    parser = subcommands.add_parser(
@@ -31,10 +50,17 @@ def add_evaluate_parser(subcommands) -> None:
        help="Configuration file.",
    )

+    parser.add_argument(
+        "--nerval-threshold",
+        help="Distance threshold for the match between gold and predicted entity during Nerval evaluation.",
+        default=NERVAL_THRESHOLD,
+        type=parse_threshold,
+    )
+
    parser.set_defaults(func=run)


-def eval(rank, config, mlflow_logging):
+def eval(rank, config: dict, nerval_threshold: float, mlflow_logging: bool):
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)
    np.random.seed(0)
@@ -62,10 +88,12 @@ def eval(rank, config, mlflow_logging):
        metric_names.append("ner")
    metrics_table = create_metrics_table(metric_names)

+    all_inferences = {}
+
    for dataset_name in config["dataset"]["datasets"]:
        for set_name in ["train", "val", "test"]:
            logger.info(f"Evaluating on set `{set_name}`")
-            metrics = model.evaluate(
+            metrics, inferences = model.evaluate(
                "{}-{}".format(dataset_name, set_name),
                [
                    (dataset_name, set_name),
@@ -75,11 +103,41 @@ def eval(rank, config, mlflow_logging):
            )

            add_metrics_table_row(metrics_table, set_name, metrics)
+            all_inferences[set_name] = inferences

    print(metrics_table)

+    if "ner" not in metric_names:
+        return
+
+    print()
+
+    def inferences_to_parsed_bio(attr: str):
+        bio_values = []
+        for inference in inferences:
+            values = getattr(inference, attr)
+            for value in values:
+                bio_value = convert(value, ner_tokens=tokens)
+                bio_values.extend(bio_value.split("\n"))
+
+        # Parse this BIO format
+        return parse_bio(bio_values)
+
+    # Evaluate with Nerval
+    tokens = parse_tokens(config["dataset"]["tokens"])
+
+    for set_name, inferences in all_inferences.items():
+        ground_truths = inferences_to_parsed_bio("ground_truth")
+        predictions = inferences_to_parsed_bio("prediction")
+
+        if not (ground_truths and predictions):
+            continue
+
+        scores = evaluate(ground_truths, predictions, nerval_threshold)
+        print(set_name)
+        print_results(scores)


-def run(config: dict):
+def run(config: dict, nerval_threshold: float):
    update_config(config)

    mlflow_logging = bool(config.get("mlflow"))
@@ -94,8 +152,8 @@ def run(config: dict):
    ):
        mp.spawn(
            eval,
-            args=(config, mlflow_logging),
+            args=(config, nerval_threshold, mlflow_logging),
            nprocs=config["training"]["device"]["nb_gpu"],
        )
    else:
-        eval(0, config, mlflow_logging)
+        eval(0, config, nerval_threshold, mlflow_logging)
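
Putting the pieces above together: each set's inferences are converted to BIO with `convert`, parsed with `parse_bio`, and scored with Nerval's `evaluate` under the chosen distance threshold. A minimal standalone sketch of that flow, reusing only the calls visible in this diff; the tokens file path and the two transcriptions are hypothetical placeholders:

```python
# Sketch of the Nerval scoring flow added above (hypothetical inputs).
from dan.bio import convert
from dan.utils import parse_tokens
from nerval.evaluate import evaluate
from nerval.parse import parse_bio
from nerval.utils import print_results

tokens = parse_tokens("tokens.yml")  # hypothetical NER tokens definition

ground_truth = "ⓈSmith ⒻJohn"  # hypothetical gold transcription with NER tokens
prediction = "ⓈSmith ⒻJean"    # hypothetical model output

# Convert both strings to BIO, parse them, then match entities with the threshold.
gold_bio = parse_bio(convert(ground_truth, ner_tokens=tokens).split("\n"))
pred_bio = parse_bio(convert(prediction, ner_tokens=tokens).split("\n"))

scores = evaluate(gold_bio, pred_bio, 0.30)
print_results(scores)
```
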
# -*- coding: utf-8 -*-
import re
from collections import defaultdict
+from dataclasses import dataclass
from operator import attrgetter
from pathlib import Path
from typing import Dict, List

@@ -23,6 +24,12 @@ REGEX_ONLY_ONE_SPACE = re.compile(r"\s+")
METRICS_KEYWORD = {"cer": "chars", "wer": "words", "ner": "tokens"}


+@dataclass
+class Inference:
+    ground_truth: List[str]
+    prediction: List[str]
+
+
class MetricManager:
    def __init__(self, metric_names: List[str], dataset_name: str, tokens: Path | None):
        self.dataset_name: str = dataset_name

@@ -40,6 +47,9 @@ class MetricManager:
        self.metric_names: List[str] = metric_names
        self.epoch_metrics = defaultdict(list)

+        # List of inferences (prediction with their ground truth)
+        self.inferences = []
+
    def format_string_for_cer(self, text: str, remove_token: bool = False):
        """
        Format string for CER computation: remove layout tokens and extra spaces

@@ -155,6 +165,8 @@ class MetricManager:
        metrics["time"] = [values["time"]]

        gt, prediction = values["str_y"], values["str_x"]
+        self.inferences.append(Inference(ground_truth=gt, prediction=prediction))
+
        for metric_name in metric_names:
            match metric_name:
                case (
...
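
The `Inference` dataclass above is what `MetricManager` now stores for every evaluated batch and what `eval()` later reads back through `getattr`. A tiny illustration, with hypothetical strings (the import path matches the one used in the next file's diff):

```python
# Inference pairs a batch's ground-truth strings with the predicted ones
# (values["str_y"] and values["str_x"] in MetricManager); strings are hypothetical.
from dan.ocr.manager.metrics import Inference

record = Inference(
    ground_truth=["ⓈSmith ⒻJohn"],
    prediction=["ⓈSmith ⒻJean"],
)

for attr in ("ground_truth", "prediction"):
    for value in getattr(record, attr):
        print(attr, value)
```
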
@@ -6,7 +6,7 @@ from copy import deepcopy
from enum import Enum
from pathlib import Path
from time import time
-from typing import Dict
+from typing import Dict, List, Tuple

import numpy as np
import torch

@@ -20,7 +20,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

-from dan.ocr.manager.metrics import MetricManager
+from dan.ocr.manager.metrics import Inference, MetricManager
from dan.ocr.manager.ocr import OCRDatasetManager
from dan.ocr.mlflow import MLFLOW_AVAILABLE, logging_metrics, logging_tags_metrics
from dan.ocr.schedulers import DropoutScheduler

@@ -750,7 +750,7 @@ class GenericTrainingManager:
    def evaluate(
        self, custom_name, sets_list, metric_names, mlflow_logging=False
-    ) -> Dict[str, int | float]:
+    ) -> Tuple[Dict[str, int | float], List[Inference]]:
        """
        Main loop for evaluation
        """

@@ -810,7 +810,7 @@ class GenericTrainingManager:
            # Log mlflow artifacts
            mlflow.log_artifact(path, "predictions")

-        return metrics
+        return metrics, self.metric_manager[custom_name].inferences

    def output_pred(self, name):
        path = self.paths["results"] / "predict_{}_{}.yaml".format(
...
+-e ./nerval
albumentations==1.3.1
arkindex-export==0.1.9
boto3==1.26.124
-editdistance==0.6.2
flashlight-text==0.0.4
imageio==2.26.1
imagesize==1.4.1

@@ -9,7 +9,6 @@ lxml==4.9.3
mdutils==1.6.0
nltk==3.8.1
numpy==1.24.3
-prettytable==3.8.0
PyYAML==6.0
scipy==1.10.1
sentencepiece==0.1.99
...
@@ -3,3 +3,38 @@
| train | 18.89 | 21.05 | 26.67 | 26.67 | 26.67 | 7.14 |
| val | 8.82 | 11.54 | 50.0 | 50.0 | 50.0 | 0.0 |
| test | 2.78 | 3.33 | 14.29 | 14.29 | 14.29 | 0.0 |
train
| tag | predicted | matched | Precision | Recall | F1 | Support |
|:---------:|:---------:|:-------:|:---------:|:------:|:-----:|:-------:|
| Surname | 2 | 2 | 1.0 | 1.0 | 1.0 | 2 |
| Patron | 2 | 0 | 0.0 | 0.0 | 0 | 1 |
| Operai | 2 | 2 | 1.0 | 1.0 | 1.0 | 2 |
| Louche | 2 | 1 | 0.5 | 0.5 | 0.5 | 2 |
| Koala | 2 | 2 | 1.0 | 1.0 | 1.0 | 2 |
| Firstname | 2 | 2 | 1.0 | 1.0 | 1.0 | 2 |
| Chalumeau | 1 | 1 | 1.0 | 1.0 | 1.0 | 1 |
| Batiment | 2 | 2 | 1.0 | 1.0 | 1.0 | 2 |
| All | 15 | 12 | 0.8 | 0.857 | 0.828 | 14 |
val
| tag | predicted | matched | Precision | Recall | F1 | Support |
|:---------:|:---------:|:-------:|:---------:|:------:|:----:|:-------:|
| Surname | 1 | 0 | 0.0 | 0.0 | 0 | 1 |
| Patron | 1 | 1 | 1.0 | 1.0 | 1.0 | 1 |
| Operai | 1 | 0 | 0.0 | 0.0 | 0 | 1 |
| Louche | 1 | 1 | 1.0 | 1.0 | 1.0 | 1 |
| Koala | 1 | 1 | 1.0 | 1.0 | 1.0 | 1 |
| Firstname | 1 | 1 | 1.0 | 1.0 | 1.0 | 1 |
| Chalumeau | 1 | 1 | 1.0 | 1.0 | 1.0 | 1 |
| Batiment | 1 | 1 | 1.0 | 1.0 | 1.0 | 1 |
| All | 8 | 6 | 0.75 | 0.75 | 0.75 | 8 |
test
| tag | predicted | matched | Precision | Recall | F1 | Support |
|:---------:|:---------:|:-------:|:---------:|:------:|:-----:|:-------:|
| Surname | 1 | 1 | 1.0 | 1.0 | 1.0 | 1 |
| Louche | 1 | 1 | 1.0 | 1.0 | 1.0 | 1 |
| Koala | 1 | 1 | 1.0 | 1.0 | 1.0 | 1 |
| Firstname | 1 | 1 | 1.0 | 1.0 | 1.0 | 1 |
| Chalumeau | 1 | 0 | 0.0 | 0.0 | 0 | 1 |
| Batiment | 1 | 1 | 1.0 | 1.0 | 1.0 | 1 |
| All | 6 | 5 | 0.833 | 0.833 | 0.833 | 6 |
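
The fixture tables above are consistent with precision = matched / predicted, recall = matched / support, and F1 as their harmonic mean. Recomputing the train split's `All` row as a quick check:

```python
# "All" row of the train table: 12 matched, 15 predicted, 14 entities in the ground truth.
matched, predicted, support = 12, 15, 14

precision = matched / predicted                       # 0.8
recall = matched / support                            # ~0.857
f1 = 2 * precision * recall / (precision + recall)    # ~0.828

print(round(precision, 3), round(recall, 3), round(f1, 3))  # 0.8 0.857 0.828
```
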
@@ -103,7 +103,7 @@ def test_evaluate(capsys, training_res, val_res, test_res, evaluate_config):
    # Use the tmp_path as base folder
    evaluate_config["training"]["output_folder"] = FIXTURES / "evaluate"

-    evaluate.run(evaluate_config)
+    evaluate.run(evaluate_config, evaluate.NERVAL_THRESHOLD)

    # Check that the evaluation results are correct
    for split_name, expected_res in zip(

@@ -129,7 +129,7 @@ def test_evaluate(capsys, training_res, val_res, test_res, evaluate_config):
    # Check the metrics Markdown table
    captured_std = capsys.readouterr()
-    last_printed_lines = captured_std.out.split("\n")[-6:]
+    last_printed_lines = captured_std.out.split("\n")[-41:]
    assert (
        "\n".join(last_printed_lines)
        == Path(FIXTURES / "evaluate" / "metrics_table.md").read_text()
...