diff --git a/dan/ocr/evaluate.py b/dan/ocr/evaluate.py
index 2ba24f33e5f3c3bb441d4a7af0cb583ddf9c6c6f..1c22b57a91ca56d6b7e46f1b0a75b237a517ef25 100644
--- a/dan/ocr/evaluate.py
+++ b/dan/ocr/evaluate.py
@@ -110,8 +110,6 @@ def eval(rank, config: dict, nerval_threshold: float, mlflow_logging: bool):
     if "ner" not in metric_names:
         return
 
-    print()
-
     def inferences_to_parsed_bio(attr: str):
         bio_values = []
         for inference in inferences:
@@ -133,7 +131,7 @@ def eval(rank, config: dict, nerval_threshold: float, mlflow_logging: bool):
             continue
 
         scores = evaluate(ground_truths, predictions, nerval_threshold)
-        print(set_name)
+        print(f"\n#### {set_name}\n")
 
         print_results(scores)
 
diff --git a/tests/data/evaluate/metrics_table.md b/tests/data/evaluate/metrics_table.md
index b3697dab44a1ae2753ae27285f08ff161dbc3a52..b33a3bf04d2018eef678bbc37841c8753a314366 100644
--- a/tests/data/evaluate/metrics_table.md
+++ b/tests/data/evaluate/metrics_table.md
@@ -4,7 +4,8 @@
 | val | 8.82 | 11.54 | 50.0 | 50.0 | 50.0 | 0.0 |
 | test | 2.78 | 3.33 | 14.29 | 14.29 | 14.29 | 0.0 |
 
-train
+#### train
+
 | tag | predicted | matched | Precision | Recall | F1 | Support |
 |:---------:|:---------:|:-------:|:---------:|:------:|:-----:|:-------:|
 | Surname | 2 | 2 | 1.0 | 1.0 | 1.0 | 2 |
@@ -16,7 +17,9 @@ train
 | Chalumeau | 1 | 1 | 1.0 | 1.0 | 1.0 | 1 |
 | Batiment | 2 | 2 | 1.0 | 1.0 | 1.0 | 2 |
 | All | 15 | 12 | 0.8 | 0.857 | 0.828 | 14 |
-val
+
+#### val
+
 | tag | predicted | matched | Precision | Recall | F1 | Support |
 |:---------:|:---------:|:-------:|:---------:|:------:|:----:|:-------:|
 | Surname | 1 | 0 | 0.0 | 0.0 | 0 | 1 |
@@ -28,7 +31,9 @@ val
 | Chalumeau | 1 | 1 | 1.0 | 1.0 | 1.0 | 1 |
 | Batiment | 1 | 1 | 1.0 | 1.0 | 1.0 | 1 |
 | All | 8 | 6 | 0.75 | 0.75 | 0.75 | 8 |
-test
+
+#### test
+
 | tag | predicted | matched | Precision | Recall | F1 | Support |
 |:---------:|:---------:|:-------:|:---------:|:------:|:-----:|:-------:|
 | Surname | 1 | 1 | 1.0 | 1.0 | 1.0 | 1 |
diff --git a/tests/test_evaluate.py b/tests/test_evaluate.py
index 57793ce8bb3e1ebe17cb81d31acf80fdfb0ee797..871c415d2d2aa77634de06b050395d006b2037f9 100644
--- a/tests/test_evaluate.py
+++ b/tests/test_evaluate.py
@@ -129,7 +129,7 @@ def test_evaluate(capsys, training_res, val_res, test_res, evaluate_config):
 
     # Check the metrics Markdown table
    captured_std = capsys.readouterr()
-    last_printed_lines = captured_std.out.split("\n")[-41:]
+    last_printed_lines = captured_std.out.split("\n")[-46:]
    assert (
        "\n".join(last_printed_lines)
        == Path(FIXTURES / "evaluate" / "metrics_table.md").read_text()
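
A quick sanity check on the new slice length in test_evaluate.py (a standalone sketch, not part of the patch; the set names and the line arithmetic are inferred from the hunks above):

# Sketch: why the tail slice grows from [-41:] to [-46:].
# Each bare `print(set_name)` wrote 1 line per set, while the new
# `print(f"\n#### {set_name}\n")` writes 3 ("" / "#### <set>" / "");
# the removed leading `print()` drops 1 more line.
sets = ("train", "val", "test")  # the three evaluated splits

lines_before = 1 + len(sets) * 1  # leading print() + one bare name per set
lines_after = len(sets) * 3       # one heading block per set, no leading print()

print(lines_after - lines_before)  # 5, i.e. the bump from 41 to 46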