From 917410f5d96ddb7f8c70c7d6e8afdb2d0c98b45f Mon Sep 17 00:00:00 2001 From: Eva Bardou <bardou@teklia.com> Date: Wed, 3 Jan 2024 16:28:31 +0000 Subject: [PATCH] Use csv.DictReader rather than csv.reader --- demo/mapping_file.csv | 1 + nerval/evaluate.py | 19 ++++++++++++++----- tests/conftest.py | 5 +++++ tests/fixtures/test_mapping_file.csv | 1 + tests/fixtures/test_mapping_file_error.csv | 3 +++ tests/test_run.py | 7 +++++++ 6 files changed, 31 insertions(+), 5 deletions(-) create mode 100644 tests/fixtures/test_mapping_file_error.csv diff --git a/demo/mapping_file.csv b/demo/mapping_file.csv index 5a2ce92..6b59366 100644 --- a/demo/mapping_file.csv +++ b/demo/mapping_file.csv @@ -1,2 +1,3 @@ +Annotation,Prediction demo_annot.bio,demo_predict.bio toy_test_annot.bio,toy_test_predict.bio \ No newline at end of file diff --git a/nerval/evaluate.py b/nerval/evaluate.py index 26f0118..d25903f 100644 --- a/nerval/evaluate.py +++ b/nerval/evaluate.py @@ -1,5 +1,5 @@ +import csv import logging -from csv import reader from pathlib import Path from typing import List @@ -19,6 +19,10 @@ from nerval.utils import print_result_compact, print_results logger = logging.getLogger(__name__) +ANNO_COLUMN = "Annotation" +PRED_COLUMN = "Prediction" +CSV_HEADER = [ANNO_COLUMN, PRED_COLUMN] + def compute_matches( annotation: str, @@ -346,7 +350,10 @@ def run_multiple(file_csv: Path, folder: Path, threshold: int, verbose: bool): """Run the program for multiple files (correlation indicated in the csv file)""" # Read the csv in a list with file_csv.open() as read_obj: - csv_reader = reader(read_obj) + csv_reader = csv.DictReader(read_obj) + assert ( + csv_reader.fieldnames == CSV_HEADER + ), f'Columns in the CSV mapping should be: {",".join(CSV_HEADER)}' list_cor = list(csv_reader) if not folder.is_dir(): @@ -363,14 +370,16 @@ def run_multiple(file_csv: Path, folder: Path, threshold: int, verbose: bool): predict = None for file in list_bio_file: - if row[0] == file.name: + if row[ANNO_COLUMN] == file.name: annot = file for file in list_bio_file: - if row[1] == file.name: + if row[PRED_COLUMN] == file.name: predict = file if not (annot and predict): - raise Exception(f"No file found for files {row[0]}, {row[1]}") + raise Exception( + f"No file found for files {row[ANNO_COLUMN]}, {row[PRED_COLUMN]}" + ) count += 1 scores = run(annot, predict, threshold, verbose) diff --git a/tests/conftest.py b/tests/conftest.py index 61b7300..2d934c3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -45,6 +45,11 @@ def folder_bio(): return FIXTURES +@pytest.fixture() +def csv_file_error(): + return FIXTURES / "test_mapping_file_error.csv" + + @pytest.fixture() def csv_file(): return FIXTURES / "test_mapping_file.csv" diff --git a/tests/fixtures/test_mapping_file.csv b/tests/fixtures/test_mapping_file.csv index 5a2ce92..6b59366 100644 --- a/tests/fixtures/test_mapping_file.csv +++ b/tests/fixtures/test_mapping_file.csv @@ -1,2 +1,3 @@ +Annotation,Prediction demo_annot.bio,demo_predict.bio toy_test_annot.bio,toy_test_predict.bio \ No newline at end of file diff --git a/tests/fixtures/test_mapping_file_error.csv b/tests/fixtures/test_mapping_file_error.csv new file mode 100644 index 0000000..f1e3b27 --- /dev/null +++ b/tests/fixtures/test_mapping_file_error.csv @@ -0,0 +1,3 @@ +Anno,Pred +demo_annot.bio,demo_predict.bio +toy_test_annot.bio,toy_test_predict.bio \ No newline at end of file diff --git a/tests/test_run.py b/tests/test_run.py index 63b8ebe..44f4e26 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -114,6 +114,13 @@ def test_run_empty_entry(): evaluate.run(Path("invalid.bio"), Path("invalid.bio"), 0.3, False) +def test_run_invalid_header(csv_file_error, folder_bio): + with pytest.raises( + Exception, match="Columns in the CSV mapping should be: Annotation,Prediction" + ): + evaluate.run_multiple(csv_file_error, folder_bio, 0.3, False) + + def test_run_multiple(csv_file, folder_bio): with pytest.raises( Exception, match="No file found for files demo_annot.bio, demo_predict.bio" -- GitLab