From e5bf32a63f6394424e9bfdf19318a31d397d6b93 Mon Sep 17 00:00:00 2001 From: EvaBardou <bardou@teklia.com> Date: Wed, 3 Jan 2024 17:01:32 +0100 Subject: [PATCH] Use csv.DictReader rather than csv.reader --- demo/mapping_file.csv | 1 + nerval/evaluate.py | 16 +++++++++++----- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/demo/mapping_file.csv b/demo/mapping_file.csv index 5a2ce92..6b59366 100644 --- a/demo/mapping_file.csv +++ b/demo/mapping_file.csv @@ -1,2 +1,3 @@ +Annotation,Prediction demo_annot.bio,demo_predict.bio toy_test_annot.bio,toy_test_predict.bio \ No newline at end of file diff --git a/nerval/evaluate.py b/nerval/evaluate.py index 5afe7cc..1a438da 100644 --- a/nerval/evaluate.py +++ b/nerval/evaluate.py @@ -1,5 +1,5 @@ +import csv import logging -from csv import reader from pathlib import Path from typing import List @@ -19,6 +19,10 @@ from nerval.utils import print_result_compact, print_results logger = logging.getLogger(__name__) +ANNO_COLUMN = "Annotation" +PRED_COLUMN = "Prediction" +CSV_HEADER = [ANNO_COLUMN, PRED_COLUMN] + def compute_matches( annotation: str, @@ -346,7 +350,7 @@ def run_multiple(file_csv: Path, folder: Path, threshold: int, verbose: bool): """Run the program for multiple files (correlation indicated in the csv file)""" # Read the csv in a list with file_csv.open() as read_obj: - csv_reader = reader(read_obj) + csv_reader = csv.DictReader(read_obj, fieldnames=CSV_HEADER) list_cor = list(csv_reader) if folder.is_dir(): @@ -361,10 +365,10 @@ def run_multiple(file_csv: Path, folder: Path, threshold: int, verbose: bool): predict = None for file in list_bio_file: - if row[0] == file.name: + if row[ANNO_COLUMN] == file.name: annot = file for file in list_bio_file: - if row[1] == file.name: + if row[PRED_COLUMN] == file.name: predict = file if annot and predict: @@ -374,7 +378,9 @@ def run_multiple(file_csv: Path, folder: Path, threshold: int, verbose: bool): recall += scores["All"]["R"] f1 += scores["All"]["F1"] else: - raise Exception(f"No file found for files {row[0]}, {row[1]}") + raise Exception( + f"No file found for files {row[ANNO_COLUMN]}, {row[PRED_COLUMN]}" + ) if count: logger.info("Average score on all corpus") table = PrettyTable() -- GitLab