Skip to content
Snippets Groups Projects
Commit 0658f885, authored by Manon Blanco, committed by Yoann Schneider
Browse files

Expose parsing/evaluation code

parent ad29708f
No related branches found
No related tags found
1 merge request: !29 "Expose parsing/evaluation code"
Pipeline #143876 passed
......@@ -4,6 +4,7 @@ import logging
import os
from csv import reader
from pathlib import Path
from typing import List
import editdistance
import edlib
......@@ -265,22 +266,11 @@ def compute_scores(
return scores
def run(annotation: Path, prediction: Path, threshold: int, verbose: bool) -> dict:
"""Compute recall and precision for each entity type found in annotation and/or prediction.
Each measure is given at document level, global score is a micro-average across entity types.
"""
# Get string and list of labels per character
annot = parse_bio(annotation)
predict = parse_bio(prediction)
if not annot or not predict:
raise Exception("No content found in annotation or prediction files.")
def evaluate(annotation: dict, prediction: dict, threshold: int) -> dict:
# Align annotation and prediction
align_result = edlib.align(annot["words"], predict["words"], task="path")
align_result = edlib.align(annotation["words"], prediction["words"], task="path")
nice_alignment = edlib.getNiceAlignment(
align_result, annot["words"], predict["words"]
align_result, annotation["words"], prediction["words"]
)
annot_aligned = nice_alignment["query_aligned"]
......@@ -288,10 +278,10 @@ def run(annotation: Path, prediction: Path, threshold: int, verbose: bool) -> di
# Align labels from string alignment
labels_annot_aligned = get_labels_aligned(
annot["words"], annot_aligned, annot["labels"]
annotation["words"], annot_aligned, annotation["labels"]
)
labels_predict_aligned = get_labels_aligned(
predict["words"], predict_aligned, predict["labels"]
prediction["words"], predict_aligned, prediction["labels"]
)
# Get nb match
......@@ -304,7 +294,33 @@ def run(annotation: Path, prediction: Path, threshold: int, verbose: bool) -> di
)
# Compute scores
scores = compute_scores(annot["entity_count"], predict["entity_count"], matches)
scores = compute_scores(
annotation["entity_count"], prediction["entity_count"], matches
)
return scores
def run(annotation: Path, prediction: Path, threshold: int, verbose: bool) -> dict:
"""Compute recall and precision for each entity type found in annotation and/or prediction.
Each measure is given at document level, global score is a micro-average across entity types.
"""
# Get string and list of labels per character
def read_file(path: Path) -> List[str]:
assert path.exists(), f"Error: Input file {path} does not exist"
return path.read_text().strip().splitlines()
logger.info(f"Parsing file @ {annotation}")
annot = parse_bio(read_file(annotation))
logger.info(f"Parsing file @ {prediction}")
predict = parse_bio(read_file(prediction))
if not (annot and predict):
raise Exception("No content found in annotation or prediction files.")
scores = evaluate(annot, predict, threshold)
# Print results
if verbose:
......
# -*- coding: utf-8 -*-
import re
from pathlib import Path
from typing import List
NOT_ENTITY_TAG = "O"
BEGINNING_POS = ["B", "S", "U"]
......@@ -34,12 +34,12 @@ def get_position_label(label: str) -> str:
else re.match(r"([BIESLU])-(.*)$", label)[1]
)
except TypeError:
raise (Exception(f"The label {label} is not valid in BIOES/BIOLU format."))
raise Exception(f"The label {label} is not valid in BIOES/BIOLU format.")
return pos
def parse_line(index: int, line: str, path: Path):
def parse_line(index: int, line: str):
try:
match_iob = REGEX_IOB_LINE.search(line)
......@@ -47,33 +47,24 @@ def parse_line(index: int, line: str, path: Path):
return match_iob.group(1, 2)
except AssertionError:
raise (
Exception(
f"The file @ {path} is not in BIO format: check line {index} ({line})"
)
)
raise Exception(f"The file is not in BIO format: check line {index} ({line})")
def parse_bio(path: Path) -> dict:
def parse_bio(lines: List[str]) -> dict:
"""Parse a BIO file to get text content, character-level NE labels and entity types count.
Input : path to a valid BIO file
Output format : { "words": str, "labels": list; "entity_count" : { tag : int } }
Input: lines of a valid BIO file
Output format: { "words": str, "labels": list, "entity_count": { tag: int } }
"""
assert path.exists(), f"Error: Input file {path} does not exist"
words = []
labels = []
entity_count = {"All": 0}
last_tag = None
with open(path, "r") as fd:
lines = list(filter(lambda x: x != "\n", fd.readlines()))
if "§" in " ".join(lines):
raise (
Exception(
f"§ found in input file {path}. Since this character is used in a specific way during evaluation, prease remove it from files."
"§ found in input file. Since this character is used in a specific way during evaluation, prease remove it from files."
)
)
......@@ -82,7 +73,7 @@ def parse_bio(path: Path) -> dict:
containing_tag = None
for index, line in enumerate(lines):
word, label = parse_line(index, line, path)
word, label = parse_line(index, line)
# Preserve hyphens to avoid confusion with the hyphens added later during alignment
word = word.replace("-", "§")
......
# -*- coding: utf-8 -*-
from pathlib import Path
import pytest
from nerval import evaluate
......@@ -181,7 +179,8 @@ expected_parsed_end_of_file = {
],
)
def test_parse_bio(test_input, expected):
assert evaluate.parse_bio(test_input) == expected
lines = test_input.read_text().strip().splitlines()
assert evaluate.parse_bio(lines) == expected
def test_parse_bio_bad_input(bad_bio):
......@@ -189,11 +188,6 @@ def test_parse_bio_bad_input(bad_bio):
evaluate.parse_bio(bad_bio)
def test_parse_bio_no_input():
with pytest.raises(AssertionError):
evaluate.parse_bio(Path("not_a_bio"))
@pytest.mark.parametrize(
"line, word, label",
(
......@@ -204,7 +198,7 @@ def test_parse_bio_no_input():
),
)
def test_parse_line(line, word, label):
assert parse_line(index=0, line=line, path=Path("")) == (word, label)
assert parse_line(index=0, line=line) == (word, label)
@pytest.mark.parametrize(
......@@ -213,7 +207,7 @@ def test_parse_line(line, word, label):
)
def test_parse_line_crash(line):
with pytest.raises(Exception):
parse_line(index=0, line=line, path=Path(""))
parse_line(index=0, line=line)
@pytest.mark.parametrize(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment