From e5bf32a63f6394424e9bfdf19318a31d397d6b93 Mon Sep 17 00:00:00 2001
From: EvaBardou <bardou@teklia.com>
Date: Wed, 3 Jan 2024 17:01:32 +0100
Subject: [PATCH] Use csv.DictReader rather than csv.reader

---
 demo/mapping_file.csv |  1 +
 nerval/evaluate.py    | 16 +++++++++++-----
 2 files changed, 12 insertions(+), 5 deletions(-)

diff --git a/demo/mapping_file.csv b/demo/mapping_file.csv
index 5a2ce92..6b59366 100644
--- a/demo/mapping_file.csv
+++ b/demo/mapping_file.csv
@@ -1,2 +1,3 @@
+Annotation,Prediction
 demo_annot.bio,demo_predict.bio
 toy_test_annot.bio,toy_test_predict.bio
\ No newline at end of file
diff --git a/nerval/evaluate.py b/nerval/evaluate.py
index 5afe7cc..1a438da 100644
--- a/nerval/evaluate.py
+++ b/nerval/evaluate.py
@@ -1,5 +1,5 @@
+import csv
 import logging
-from csv import reader
 from pathlib import Path
 from typing import List
 
@@ -19,6 +19,10 @@ from nerval.utils import print_result_compact, print_results
 
 logger = logging.getLogger(__name__)
 
+ANNO_COLUMN = "Annotation"
+PRED_COLUMN = "Prediction"
+CSV_HEADER = [ANNO_COLUMN, PRED_COLUMN]
+
 
 def compute_matches(
     annotation: str,
@@ -346,7 +350,7 @@ def run_multiple(file_csv: Path, folder: Path, threshold: int, verbose: bool):
     """Run the program for multiple files (correlation indicated in the csv file)"""
     # Read the csv in a list
     with file_csv.open() as read_obj:
-        csv_reader = reader(read_obj)
+        csv_reader = csv.DictReader(read_obj)  # header row supplies the field names
         list_cor = list(csv_reader)
 
     if folder.is_dir():
@@ -361,10 +365,10 @@ def run_multiple(file_csv: Path, folder: Path, threshold: int, verbose: bool):
             predict = None
 
             for file in list_bio_file:
-                if row[0] == file.name:
+                if row[ANNO_COLUMN] == file.name:
                     annot = file
             for file in list_bio_file:
-                if row[1] == file.name:
+                if row[PRED_COLUMN] == file.name:
                     predict = file
 
             if annot and predict:
@@ -374,7 +378,9 @@ def run_multiple(file_csv: Path, folder: Path, threshold: int, verbose: bool):
                 recall += scores["All"]["R"]
                 f1 += scores["All"]["F1"]
             else:
-                raise Exception(f"No file found for files {row[0]}, {row[1]}")
+                raise Exception(
+                    f"No file found for files {row[ANNO_COLUMN]}, {row[PRED_COLUMN]}"
+                )
         if count:
             logger.info("Average score on all corpus")
             table = PrettyTable()
-- 
GitLab