From 917410f5d96ddb7f8c70c7d6e8afdb2d0c98b45f Mon Sep 17 00:00:00 2001
From: Eva Bardou <bardou@teklia.com>
Date: Wed, 3 Jan 2024 16:28:31 +0000
Subject: [PATCH] Use csv.DictReader rather than csv.reader

---
 demo/mapping_file.csv                      |  1 +
 nerval/evaluate.py                         | 19 ++++++++++++++-----
 tests/conftest.py                          |  5 +++++
 tests/fixtures/test_mapping_file.csv       |  1 +
 tests/fixtures/test_mapping_file_error.csv |  3 +++
 tests/test_run.py                          |  7 +++++++
 6 files changed, 31 insertions(+), 5 deletions(-)
 create mode 100644 tests/fixtures/test_mapping_file_error.csv

diff --git a/demo/mapping_file.csv b/demo/mapping_file.csv
index 5a2ce92..6b59366 100644
--- a/demo/mapping_file.csv
+++ b/demo/mapping_file.csv
@@ -1,2 +1,3 @@
+Annotation,Prediction
 demo_annot.bio,demo_predict.bio
 toy_test_annot.bio,toy_test_predict.bio
\ No newline at end of file
diff --git a/nerval/evaluate.py b/nerval/evaluate.py
index 26f0118..d25903f 100644
--- a/nerval/evaluate.py
+++ b/nerval/evaluate.py
@@ -1,5 +1,5 @@
+import csv
 import logging
-from csv import reader
 from pathlib import Path
 from typing import List
 
@@ -19,6 +19,10 @@ from nerval.utils import print_result_compact, print_results
 
 logger = logging.getLogger(__name__)
 
+ANNO_COLUMN = "Annotation"  # CSV column holding the annotation file name
+PRED_COLUMN = "Prediction"  # CSV column holding the prediction file name
+CSV_HEADER = [ANNO_COLUMN, PRED_COLUMN]  # exact header run_multiple requires
+
 
 def compute_matches(
     annotation: str,
@@ -346,7 +350,10 @@ def run_multiple(file_csv: Path, folder: Path, threshold: int, verbose: bool):
     """Run the program for multiple files (correlation indicated in the csv file)"""
     # Read the csv in a list
     with file_csv.open() as read_obj:
-        csv_reader = reader(read_obj)
+        csv_reader = csv.DictReader(read_obj)
+        if csv_reader.fieldnames != CSV_HEADER:
+            msg = f'Columns in the CSV mapping should be: {",".join(CSV_HEADER)}'
+            raise ValueError(msg)  # not an assert: stripped by python -O
         list_cor = list(csv_reader)
 
     if not folder.is_dir():
@@ -363,14 +370,16 @@ def run_multiple(file_csv: Path, folder: Path, threshold: int, verbose: bool):
         predict = None
 
         for file in list_bio_file:
-            if row[0] == file.name:
+            if row[ANNO_COLUMN] == file.name:
                 annot = file
         for file in list_bio_file:
-            if row[1] == file.name:
+            if row[PRED_COLUMN] == file.name:
                 predict = file
 
         if not (annot and predict):
-            raise Exception(f"No file found for files {row[0]}, {row[1]}")
+            raise Exception(
+                f"No file found for files {row[ANNO_COLUMN]}, {row[PRED_COLUMN]}"
+            )
 
         count += 1
         scores = run(annot, predict, threshold, verbose)
diff --git a/tests/conftest.py b/tests/conftest.py
index 61b7300..2d934c3 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -45,6 +45,11 @@ def folder_bio():
     return FIXTURES
 
 
+@pytest.fixture()
+def csv_file_error():
+    return FIXTURES / "test_mapping_file_error.csv"  # header is "Anno,Pred"
+
+
 @pytest.fixture()
 def csv_file():
     return FIXTURES / "test_mapping_file.csv"
diff --git a/tests/fixtures/test_mapping_file.csv b/tests/fixtures/test_mapping_file.csv
index 5a2ce92..6b59366 100644
--- a/tests/fixtures/test_mapping_file.csv
+++ b/tests/fixtures/test_mapping_file.csv
@@ -1,2 +1,3 @@
+Annotation,Prediction
 demo_annot.bio,demo_predict.bio
 toy_test_annot.bio,toy_test_predict.bio
\ No newline at end of file
diff --git a/tests/fixtures/test_mapping_file_error.csv b/tests/fixtures/test_mapping_file_error.csv
new file mode 100644
index 0000000..f1e3b27
--- /dev/null
+++ b/tests/fixtures/test_mapping_file_error.csv
@@ -0,0 +1,3 @@
+Anno,Pred
+demo_annot.bio,demo_predict.bio
+toy_test_annot.bio,toy_test_predict.bio
\ No newline at end of file
diff --git a/tests/test_run.py b/tests/test_run.py
index 63b8ebe..44f4e26 100644
--- a/tests/test_run.py
+++ b/tests/test_run.py
@@ -114,6 +114,13 @@ def test_run_empty_entry():
         evaluate.run(Path("invalid.bio"), Path("invalid.bio"), 0.3, False)
 
 
+def test_run_invalid_header(csv_file_error, folder_bio):
+    """A mapping CSV with unexpected column names must be rejected."""
+    expected = "Columns in the CSV mapping should be: Annotation,Prediction"
+    with pytest.raises(Exception, match=expected):
+        evaluate.run_multiple(csv_file_error, folder_bio, 0.3, False)
+
+
 def test_run_multiple(csv_file, folder_bio):
     with pytest.raises(
         Exception, match="No file found for files demo_annot.bio, demo_predict.bio"
-- 
GitLab