Commit 9ada05d2 authored by Manon Blanco

Expose Nerval threshold

parent d5b1e9a2
@@ -5,6 +5,7 @@ Evaluate a trained DAN model.
 import logging
 import random
+from argparse import ArgumentTypeError

 import numpy as np
 import torch
@@ -20,6 +21,20 @@ from nerval.utils import print_results

 logger = logging.getLogger(__name__)

+NERVAL_THRESHOLD = 0.30
+
+
+def parse_threshold(arg):
+    try:
+        f = float(arg)
+    except ValueError:
+        raise ArgumentTypeError("Must be a floating point number.")
+
+    if f < 0 or f > 1:
+        raise ArgumentTypeError("Must be between 0 and 1.")
+
+    return f
+
+
 def add_evaluate_parser(subcommands) -> None:
     parser = subcommands.add_parser(
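The added parse_threshold validator rejects any value outside [0, 1] before it reaches the evaluation code. Called directly, it behaves as sketched below (illustrative calls, not part of the commit):

    parse_threshold("0.25")  # returns 0.25
    parse_threshold("1.5")   # ArgumentTypeError: Must be between 0 and 1.
    parse_threshold("high")  # ArgumentTypeError: Must be a floating point number.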
@@ -35,10 +50,17 @@ def add_evaluate_parser(subcommands) -> None:
         help="Configuration file.",
     )
+    parser.add_argument(
+        "--nerval-threshold",
+        help="Distance threshold for the match between gold and predicted entity during Nerval evaluation.",
+        default=NERVAL_THRESHOLD,
+        type=parse_threshold,
+    )
     parser.set_defaults(func=run)


-def eval(rank, config, mlflow_logging):
+def eval(rank, config: dict, nerval_threshold: float, mlflow_logging: bool):
     torch.manual_seed(0)
     torch.cuda.manual_seed(0)
     np.random.seed(0)
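With the argument registered and eval reworked to accept it, the parsed value travels from the command line down to the Nerval call. A minimal sketch of that wiring, assuming parse_threshold and NERVAL_THRESHOLD from this commit are in scope (the standalone parser below is hypothetical and only mirrors the add_argument call above; the real subcommand also takes the configuration file):

    from argparse import ArgumentParser

    demo = ArgumentParser()
    demo.add_argument(
        "--nerval-threshold",
        default=NERVAL_THRESHOLD,
        type=parse_threshold,
    )
    args = demo.parse_args(["--nerval-threshold", "0.25"])
    # args.nerval_threshold == 0.25, later handed to run()/eval() as nerval_threshold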
@@ -110,12 +132,12 @@ def eval(rank, config, mlflow_logging):
         if not (ground_truths and predictions):
             continue

-        scores = evaluate(ground_truths, predictions, 0.30)
+        scores = evaluate(ground_truths, predictions, nerval_threshold)
         print(set_name)
         print_results(scores)


-def run(config: dict):
+def run(config: dict, nerval_threshold: float):
     update_config(config)

     mlflow_logging = bool(config.get("mlflow"))
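The hard-coded 0.30 previously passed to Nerval's evaluate is now the nerval_threshold parameter. Per the new argument's help text, it is the distance threshold for matching a gold entity with a predicted one; the sketch below only illustrates that kind of threshold-based match with a generic similarity measure, it is not Nerval's own matching code:

    from difflib import SequenceMatcher

    def is_match(gold: str, predicted: str, threshold: float = 0.30) -> bool:
        # Illustration only: accept the pair when the share of differing
        # characters stays at or below the threshold.
        similarity = SequenceMatcher(None, gold, predicted).ratio()
        return (1 - similarity) <= threshold

    is_match("Manon Blanco", "Manon Blanko")  # True: small distance
    is_match("Manon Blanco", "Dupont")        # False: too dissimilar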
@@ -130,8 +152,8 @@ def run(config: dict):
     ):
         mp.spawn(
             eval,
-            args=(config, mlflow_logging),
+            args=(config, nerval_threshold, mlflow_logging),
             nprocs=config["training"]["device"]["nb_gpu"],
         )
     else:
-        eval(0, config, mlflow_logging)
+        eval(0, config, nerval_threshold, mlflow_logging)
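torch.multiprocessing.spawn prepends the process rank before unpacking args, so the tuple only carries the remaining parameters and must keep the same order as eval's signature, which is why nerval_threshold is inserted between config and mlflow_logging. A small self-contained stub showing that calling convention (the stub function is illustrative, not the project's eval):

    import torch.multiprocessing as mp

    def eval_stub(rank, config, nerval_threshold, mlflow_logging):
        # mp.spawn calls eval_stub(i, *args) for every process i in range(nprocs).
        print(rank, config, nerval_threshold, mlflow_logging)

    if __name__ == "__main__":
        mp.spawn(eval_stub, args=({}, 0.30, False), nprocs=2)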
@@ -103,7 +103,7 @@ def test_evaluate(capsys, training_res, val_res, test_res, evaluate_config):
     # Use the tmp_path as base folder
     evaluate_config["training"]["output_folder"] = FIXTURES / "evaluate"

-    evaluate.run(evaluate_config)
+    evaluate.run(evaluate_config, evaluate.NERVAL_THRESHOLD)

     # Check that the evaluation results are correct
     for split_name, expected_res in zip(