diff --git a/demo/mapping_file.csv b/demo/mapping_file.csv index 5a2ce9244d8751ec5812a297c8b7ddf367ed3a56..6b593660805f8e6d6257c16c015e590d9b068983 100644 --- a/demo/mapping_file.csv +++ b/demo/mapping_file.csv @@ -1,2 +1,3 @@ +Annotation,Prediction demo_annot.bio,demo_predict.bio toy_test_annot.bio,toy_test_predict.bio \ No newline at end of file diff --git a/nerval/evaluate.py b/nerval/evaluate.py index 26f01184480903afd8bdd1d554746be383e03f6b..d25903f08679bc9b8cb902c211300c55b336501f 100644 --- a/nerval/evaluate.py +++ b/nerval/evaluate.py @@ -1,5 +1,5 @@ +import csv import logging -from csv import reader from pathlib import Path from typing import List @@ -19,6 +19,10 @@ from nerval.utils import print_result_compact, print_results logger = logging.getLogger(__name__) +ANNO_COLUMN = "Annotation" +PRED_COLUMN = "Prediction" +CSV_HEADER = [ANNO_COLUMN, PRED_COLUMN] + def compute_matches( annotation: str, @@ -346,7 +350,10 @@ def run_multiple(file_csv: Path, folder: Path, threshold: int, verbose: bool): """Run the program for multiple files (correlation indicated in the csv file)""" # Read the csv in a list with file_csv.open() as read_obj: - csv_reader = reader(read_obj) + csv_reader = csv.DictReader(read_obj) + assert ( + csv_reader.fieldnames == CSV_HEADER + ), f'Columns in the CSV mapping should be: {",".join(CSV_HEADER)}' list_cor = list(csv_reader) if not folder.is_dir(): @@ -363,14 +370,16 @@ def run_multiple(file_csv: Path, folder: Path, threshold: int, verbose: bool): predict = None for file in list_bio_file: - if row[0] == file.name: + if row[ANNO_COLUMN] == file.name: annot = file for file in list_bio_file: - if row[1] == file.name: + if row[PRED_COLUMN] == file.name: predict = file if not (annot and predict): - raise Exception(f"No file found for files {row[0]}, {row[1]}") + raise Exception( + f"No file found for files {row[ANNO_COLUMN]}, {row[PRED_COLUMN]}" + ) count += 1 scores = run(annot, predict, threshold, verbose) diff --git a/tests/conftest.py b/tests/conftest.py index 61b73008682e5ec8321c6130da8cf4c50f4f2616..2d934c3add4f78e9762e6abc35a7acf722e55111 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -45,6 +45,11 @@ def folder_bio(): return FIXTURES +@pytest.fixture() +def csv_file_error(): + return FIXTURES / "test_mapping_file_error.csv" + + @pytest.fixture() def csv_file(): return FIXTURES / "test_mapping_file.csv" diff --git a/tests/fixtures/test_mapping_file.csv b/tests/fixtures/test_mapping_file.csv index 5a2ce9244d8751ec5812a297c8b7ddf367ed3a56..6b593660805f8e6d6257c16c015e590d9b068983 100644 --- a/tests/fixtures/test_mapping_file.csv +++ b/tests/fixtures/test_mapping_file.csv @@ -1,2 +1,3 @@ +Annotation,Prediction demo_annot.bio,demo_predict.bio toy_test_annot.bio,toy_test_predict.bio \ No newline at end of file diff --git a/tests/fixtures/test_mapping_file_error.csv b/tests/fixtures/test_mapping_file_error.csv new file mode 100644 index 0000000000000000000000000000000000000000..f1e3b278076f00f891dbb38e621b89e55df94340 --- /dev/null +++ b/tests/fixtures/test_mapping_file_error.csv @@ -0,0 +1,3 @@ +Anno,Pred +demo_annot.bio,demo_predict.bio +toy_test_annot.bio,toy_test_predict.bio \ No newline at end of file diff --git a/tests/test_run.py b/tests/test_run.py index 63b8ebe8195a8b0e63d58bd05789d7f7af511a0b..44f4e2693925ccfe56ae18bbb91f2429ab7705f4 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -114,6 +114,13 @@ def test_run_empty_entry(): evaluate.run(Path("invalid.bio"), Path("invalid.bio"), 0.3, False) +def test_run_invalid_header(csv_file_error, folder_bio): + with pytest.raises( + Exception, match="Columns in the CSV mapping should be: Annotation,Prediction" + ): + evaluate.run_multiple(csv_file_error, folder_bio, 0.3, False) + + def test_run_multiple(csv_file, folder_bio): with pytest.raises( Exception, match="No file found for files demo_annot.bio, demo_predict.bio"