Skip to content
Snippets Groups Projects

Evaluate predictions with nerval

Merged Manon Blanco requested to merge nerval-evaluate into main
All threads resolved!
Files
6
+ 28
8
@@ -5,6 +5,7 @@ Evaluate a trained DAN model.
import logging
import random
from argparse import ArgumentTypeError
import numpy as np
import torch
@@ -20,6 +21,20 @@ from nerval.utils import print_results
logger = logging.getLogger(__name__)
NERVAL_THRESHOLD = 0.30
def parse_threshold(arg: str) -> float:
    """Argparse ``type`` callback: parse *arg* as a float in [0, 1].

    Raises ArgumentTypeError (which argparse turns into a clean CLI error)
    when the value is not a float or falls outside the inclusive range.
    """
    try:
        f = float(arg)
    except ValueError as err:
        # Chain the original ValueError so tracebacks show the root cause.
        raise ArgumentTypeError("Must be a floating point number.") from err
    # Chained comparison: inclusive bounds, so 0 and 1 are both accepted.
    if not 0 <= f <= 1:
        raise ArgumentTypeError("Must be between 0 and 1.")
    return f
def add_evaluate_parser(subcommands) -> None:
parser = subcommands.add_parser(
@@ -35,10 +50,17 @@ def add_evaluate_parser(subcommands) -> None:
help="Configuration file.",
)
parser.add_argument(
"--nerval-threshold",
help="Distance threshold for the match between gold and predicted entity during Nerval evaluation.",
default=NERVAL_THRESHOLD,
type=parse_threshold,
)
parser.set_defaults(func=run)
def eval(rank, config, mlflow_logging):
def eval(rank, config: dict, nerval_threshold: float, mlflow_logging: bool):
torch.manual_seed(0)
torch.cuda.manual_seed(0)
np.random.seed(0)
@@ -88,8 +110,6 @@ def eval(rank, config, mlflow_logging):
if "ner" not in metric_names:
return
print()
def inferences_to_parsed_bio(attr: str):
bio_values = []
for inference in inferences:
@@ -110,12 +130,12 @@ def eval(rank, config, mlflow_logging):
if not (ground_truths and predictions):
continue
scores = evaluate(ground_truths, predictions, 0.30)
print(set_name)
scores = evaluate(ground_truths, predictions, nerval_threshold)
print(f"\n#### {set_name}\n")
print_results(scores)
def run(config: dict):
def run(config: dict, nerval_threshold: float):
update_config(config)
mlflow_logging = bool(config.get("mlflow"))
@@ -130,8 +150,8 @@ def run(config: dict):
):
mp.spawn(
eval,
args=(config, mlflow_logging),
args=(config, nerval_threshold, mlflow_logging),
nprocs=config["training"]["device"]["nb_gpu"],
)
else:
eval(0, config, mlflow_logging)
eval(0, config, nerval_threshold, mlflow_logging)
Loading