diff --git a/nerval/evaluate.py b/nerval/evaluate.py index a2af0bd904f25b2e7d328cb1fb84c8f171a682be..0efe5b4a53cd4c8fabfa9dad884e10e92c6df99f 100644 --- a/nerval/evaluate.py +++ b/nerval/evaluate.py @@ -4,6 +4,7 @@ import logging import os from csv import reader from pathlib import Path +from typing import List import editdistance import edlib @@ -265,22 +266,11 @@ def compute_scores( return scores -def run(annotation: Path, prediction: Path, threshold: int, verbose: bool) -> dict: - """Compute recall and precision for each entity type found in annotation and/or prediction. - - Each measure is given at document level, global score is a micro-average across entity types. - """ - # Get string and list of labels per character - annot = parse_bio(annotation) - predict = parse_bio(prediction) - - if not annot or not predict: - raise Exception("No content found in annotation or prediction files.") - +def evaluate(annotation: dict, prediction: dict, threshold: int) -> dict: # Align annotation and prediction - align_result = edlib.align(annot["words"], predict["words"], task="path") + align_result = edlib.align(annotation["words"], prediction["words"], task="path") nice_alignment = edlib.getNiceAlignment( - align_result, annot["words"], predict["words"] + align_result, annotation["words"], prediction["words"] ) annot_aligned = nice_alignment["query_aligned"] @@ -288,10 +278,10 @@ def run(annotation: Path, prediction: Path, threshold: int, verbose: bool) -> di # Align labels from string alignment labels_annot_aligned = get_labels_aligned( - annot["words"], annot_aligned, annot["labels"] + annotation["words"], annot_aligned, annotation["labels"] ) labels_predict_aligned = get_labels_aligned( - predict["words"], predict_aligned, predict["labels"] + prediction["words"], predict_aligned, prediction["labels"] ) # Get nb match @@ -304,7 +294,33 @@ def run(annotation: Path, prediction: Path, threshold: int, verbose: bool) -> di ) # Compute scores - scores = compute_scores(annot["entity_count"], predict["entity_count"], matches) + scores = compute_scores( + annotation["entity_count"], prediction["entity_count"], matches + ) + return scores + + +def run(annotation: Path, prediction: Path, threshold: int, verbose: bool) -> dict: + """Compute recall and precision for each entity type found in annotation and/or prediction. + + Each measure is given at document level, global score is a micro-average across entity types. + """ + + # Get string and list of labels per character + def read_file(path: Path) -> List[str]: + assert path.exists(), f"Error: Input file {path} does not exist" + return path.read_text().strip().splitlines() + + logger.info(f"Parsing file @ {annotation}") + annot = parse_bio(read_file(annotation)) + + logger.info(f"Parsing file @ {prediction}") + predict = parse_bio(read_file(prediction)) + + if not (annot and predict): + raise Exception("No content found in annotation or prediction files.") + + scores = evaluate(annot, predict, threshold) # Print results if verbose: diff --git a/nerval/parse.py b/nerval/parse.py index edf27155bda3036bb4c23e7c6adf84b6e6cf34b4..c4736fa5dfc687b1c59e27c908ac0dc82060fc8f 100644 --- a/nerval/parse.py +++ b/nerval/parse.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- import re -from pathlib import Path +from typing import List NOT_ENTITY_TAG = "O" BEGINNING_POS = ["B", "S", "U"] @@ -34,12 +34,12 @@ def get_position_label(label: str) -> str: else re.match(r"([BIESLU])-(.*)$", label)[1] ) except TypeError: - raise (Exception(f"The label {label} is not valid in BIOES/BIOLU format.")) + raise Exception(f"The label {label} is not valid in BIOES/BIOLU format.") return pos -def parse_line(index: int, line: str, path: Path): +def parse_line(index: int, line: str): try: match_iob = REGEX_IOB_LINE.search(line) @@ -47,33 +47,24 @@ def parse_line(index: int, line: str, path: Path): return match_iob.group(1, 2) except AssertionError: - raise ( - Exception( - f"The file @ {path} is not in BIO format: check line {index} ({line})" - ) - ) + raise Exception(f"The file is not in BIO format: check line {index} ({line})") -def parse_bio(path: Path) -> dict: +def parse_bio(lines: List[str]) -> dict: """Parse a BIO file to get text content, character-level NE labels and entity types count. - Input : path to a valid BIO file - Output format : { "words": str, "labels": list; "entity_count" : { tag : int } } + Input: lines of a valid BIO file + Output format: { "words": str, "labels": list, "entity_count": { tag: int } } """ - assert path.exists(), f"Error: Input file {path} does not exist" - words = [] labels = [] entity_count = {"All": 0} last_tag = None - with open(path, "r") as fd: - lines = list(filter(lambda x: x != "\n", fd.readlines())) - if "§" in " ".join(lines): raise ( Exception( - f"§ found in input file {path}. Since this character is used in a specific way during evaluation, prease remove it from files." + "§ found in input file. Since this character is used in a specific way during evaluation, prease remove it from files." ) ) @@ -82,7 +73,7 @@ def parse_bio(path: Path) -> dict: containing_tag = None for index, line in enumerate(lines): - word, label = parse_line(index, line, path) + word, label = parse_line(index, line) # Preserve hyphens to avoid confusion with the hyphens added later during alignment word = word.replace("-", "§") diff --git a/tests/test_parse_bio.py b/tests/test_parse_bio.py index 09c3cb589f80ccec8386bd8794e441447299d619..10897953bde5caa825ea0070fd558d6498eb6f18 100644 --- a/tests/test_parse_bio.py +++ b/tests/test_parse_bio.py @@ -1,6 +1,4 @@ # -*- coding: utf-8 -*- -from pathlib import Path - import pytest from nerval import evaluate @@ -181,7 +179,8 @@ expected_parsed_end_of_file = { ], ) def test_parse_bio(test_input, expected): - assert evaluate.parse_bio(test_input) == expected + lines = test_input.read_text().strip().splitlines() + assert evaluate.parse_bio(lines) == expected def test_parse_bio_bad_input(bad_bio): @@ -189,11 +188,6 @@ def test_parse_bio_bad_input(bad_bio): evaluate.parse_bio(bad_bio) -def test_parse_bio_no_input(): - with pytest.raises(AssertionError): - evaluate.parse_bio(Path("not_a_bio")) - - @pytest.mark.parametrize( "line, word, label", ( @@ -204,7 +198,7 @@ def test_parse_bio_no_input(): ), ) def test_parse_line(line, word, label): - assert parse_line(index=0, line=line, path=Path("")) == (word, label) + assert parse_line(index=0, line=line) == (word, label) @pytest.mark.parametrize( @@ -213,7 +207,7 @@ def test_parse_line(line, word, label): ) def test_parse_line_crash(line): with pytest.raises(Exception): - parse_line(index=0, line=line, path=Path("")) + parse_line(index=0, line=line) @pytest.mark.parametrize(