From d9eb90635c2a4c17257933033696691264a4e7a8 Mon Sep 17 00:00:00 2001
From: Charlotte Mauvezin <charlotte.mauvezin@irht.cnrs.fr>
Date: Mon, 3 Jan 2022 16:04:25 +0100
Subject: [PATCH] Rebase

---
 nerval/evaluate.py | 90 +++++++++++++++++++++++++++++++++++++++++-----
 1 file changed, 82 insertions(+), 8 deletions(-)

diff --git a/nerval/evaluate.py b/nerval/evaluate.py
index 6601a72..ab11dde 100644
--- a/nerval/evaluate.py
+++ b/nerval/evaluate.py
@@ -1,9 +1,12 @@
 # -*- coding: utf-8 -*-
 
 import argparse
+import glob
 import logging
 import os
 import re
+from csv import reader
+from pathlib import Path
 
 import editdistance
 import edlib
@@ -81,7 +84,11 @@ def parse_bio(path: str) -> dict:
         try:
             word, label = line.split()
         except ValueError:
-            raise (Exception(f"The file {path} given in input is not in BIO format."))
+            raise (
+                Exception(
+                    f"The file {path} given in input is not in BIO format: check line {index} ({line})"
+                )
+            )
 
         # Preserve hyphens to avoid confusion with the hyphens added later during alignment
         word = word.replace("-", "§")
@@ -553,6 +560,37 @@ def run(annotation: str, prediction: str, threshold: int, verbose: bool) -> dict
     return scores
 
 
+def run_multiple(file_csv, folder, threshold, verbose):
+    """Run the program for multiple files (correlation indicated in the csv file)"""
+    # Read the csv in a list
+    with open(file_csv, "r") as read_obj:
+        csv_reader = reader(read_obj)
+        list_cor = list(csv_reader)
+
+    if os.path.isdir(folder):
+        list_bio_file = glob.glob(str(folder) + "/**/*.bio", recursive=True)
+
+        for row in list_cor:
+            annot = None
+            predict = None
+
+            for file in list_bio_file:
+                if row[0] == os.path.basename(file):
+                    annot = file
+            for file in list_bio_file:
+                if row[1] == os.path.basename(file):
+                    predict = file
+
+            if annot and predict:
+                print(os.path.basename(predict))
+                run(annot, predict, threshold, verbose)
+                print()
+            else:
+                raise f"No file found for files {annot}, {predict}"
+    else:
+        raise Exception("the path indicated does not lead to a folder.")
+
+
 def threshold_float_type(arg):
     """Type function for argparse."""
     try:
@@ -571,30 +609,66 @@ def main():
 
     parser = argparse.ArgumentParser(description="Compute score of NER on predict.")
     parser.add_argument(
-        "-a", "--annot", help="Annotation in BIO format.", required=True
+        "-m",
+        "--multiple",
+        help="Single if 1, multiple 2",
+        type=int,
+        required=True,
+    )
+    parser.add_argument(
+        "-a",
+        "--annot",
+        help="Annotation in BIO format.",
     )
     parser.add_argument(
-        "-p", "--predict", help="Prediction in BIO format.", required=True
+        "-p",
+        "--predict",
+        help="Prediction in BIO format.",
     )
     parser.add_argument(
         "-t",
         "--threshold",
         help="Set a distance threshold for the match between gold and predicted entity.",
-        required=False,
         default=THRESHOLD,
         type=threshold_float_type,
     )
+    parser.add_argument(
+        "-c",
+        "--csv",
+        help="Csv with the correlation between the annotation bio files and the predict bio files",
+        type=Path,
+    )
+    parser.add_argument(
+        "-f",
+        "--folder",
+        help="Folder containing the bio files referred to in the csv file",
+        type=Path,
+    )
     parser.add_argument(
         "-v",
         "--verbose",
-        help="Print only the recap if False and detailed results if True (default)",
+        help="Print only the recap if False",
         action="store_false",
-        default="True",
     )
-
     args = parser.parse_args()
 
-    run(args.annot, args.predict, args.threshold, args.verbose)
+    if args.multiple == 1 or args.multiple == 2:
+        if args.multiple == 2:
+            if not args.folder:
+                raise argparse.ArgumentError(args.folder, "-f must be given if -m is 2")
+            if not args.csv:
+                raise argparse.ArgumentError(args.folder, "-c must be given if -m is 2")
+            if args.folder and args.csv:
+                run_multiple(args.csv, args.folder, args.threshold, args.verbose)
+        if args.multiple == 1:
+            if not args.annot:
+                raise argparse.ArgumentError(args.folder, "-a must be given if -m is 1")
+            if not args.predict:
+                raise argparse.ArgumentError(args.folder, "-p must be given if -m is 1")
+            if args.annot and args.predict:
+                run(args.annot, args.predict, args.threshold, args.verbose)
+    else:
+        raise argparse.ArgumentTypeError("Value has to be 1 or 2")
 
 
 if __name__ == "__main__":
-- 
GitLab