diff --git a/dan/ocr/evaluate.py b/dan/ocr/evaluate.py
index 11fdf12d3c1ae5f2b302a863ef7714b85bf0becc..2ba24f33e5f3c3bb441d4a7af0cb583ddf9c6c6f 100644
--- a/dan/ocr/evaluate.py
+++ b/dan/ocr/evaluate.py
@@ -5,6 +5,7 @@ Evaluate a trained DAN model.
 
 import logging
 import random
+from argparse import ArgumentTypeError
 
 import numpy as np
 import torch
@@ -20,6 +21,20 @@ from nerval.utils import print_results
 
 logger = logging.getLogger(__name__)
 
+NERVAL_THRESHOLD = 0.30
+
+
+def parse_threshold(arg):
+    try:
+        f = float(arg)
+    except ValueError:
+        raise ArgumentTypeError("Must be a floating point number.")
+
+    if f < 0 or f > 1:
+        raise ArgumentTypeError("Must be between 0 and 1.")
+
+    return f
+
 
 def add_evaluate_parser(subcommands) -> None:
     parser = subcommands.add_parser(
@@ -35,10 +50,17 @@ def add_evaluate_parser(subcommands) -> None:
         help="Configuration file.",
     )
 
+    parser.add_argument(
+        "--nerval-threshold",
+        help="Distance threshold for the match between gold and predicted entity during Nerval evaluation.",
+        default=NERVAL_THRESHOLD,
+        type=parse_threshold,
+    )
+
     parser.set_defaults(func=run)
 
 
-def eval(rank, config, mlflow_logging):
+def eval(rank, config: dict, nerval_threshold: float, mlflow_logging: bool):
     torch.manual_seed(0)
     torch.cuda.manual_seed(0)
     np.random.seed(0)
@@ -110,12 +132,12 @@ def eval(rank, config, mlflow_logging):
         if not (ground_truths and predictions):
             continue
 
-        scores = evaluate(ground_truths, predictions, 0.30)
+        scores = evaluate(ground_truths, predictions, nerval_threshold)
         print(set_name)
         print_results(scores)
 
 
-def run(config: dict):
+def run(config: dict, nerval_threshold: float):
    update_config(config)
 
     mlflow_logging = bool(config.get("mlflow"))
@@ -130,8 +152,8 @@ def run(config: dict):
     ):
         mp.spawn(
             eval,
-            args=(config, mlflow_logging),
+            args=(config, nerval_threshold, mlflow_logging),
             nprocs=config["training"]["device"]["nb_gpu"],
         )
     else:
-        eval(0, config, mlflow_logging)
+        eval(0, config, nerval_threshold, mlflow_logging)
diff --git a/tests/test_evaluate.py b/tests/test_evaluate.py
index 38cec57b7ccfb0ab18d12802938202d9fb65c872..57793ce8bb3e1ebe17cb81d31acf80fdfb0ee797 100644
--- a/tests/test_evaluate.py
+++ b/tests/test_evaluate.py
@@ -103,7 +103,7 @@ def test_evaluate(capsys, training_res, val_res, test_res, evaluate_config):
     # Use the tmp_path as base folder
     evaluate_config["training"]["output_folder"] = FIXTURES / "evaluate"
 
-    evaluate.run(evaluate_config)
+    evaluate.run(evaluate_config, evaluate.NERVAL_THRESHOLD)
 
     # Check that the evaluation results are correct
     for split_name, expected_res in zip(
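
For reviewers, a minimal sketch of how the new threshold validation behaves, using only the `parse_threshold` helper and `NERVAL_THRESHOLD` constant introduced in this diff (the import path follows the file path `dan/ocr/evaluate.py`):

```python
from argparse import ArgumentTypeError

from dan.ocr.evaluate import NERVAL_THRESHOLD, parse_threshold

assert NERVAL_THRESHOLD == 0.30
assert parse_threshold("0.25") == 0.25  # valid value inside [0, 1]
assert parse_threshold("1") == 1.0      # both bounds are inclusive

# Non-numeric and out-of-range values raise ArgumentTypeError,
# which argparse converts into a clean CLI usage error.
for bad in ("abc", "-0.1", "1.5"):
    try:
        parse_threshold(bad)
    except ArgumentTypeError as err:
        print(f"{bad!r} rejected: {err}")
```

Because `--nerval-threshold` defaults to `NERVAL_THRESHOLD`, existing invocations keep the previous hard-coded 0.30 behaviour; only the test suite needed updating, since `evaluate.run` now takes the threshold explicitly.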