From 2ac2b607301c72b555ef2817b1b8b0b676f2fe39 Mon Sep 17 00:00:00 2001 From: Yoann Schneider <yschneider@teklia.com> Date: Wed, 1 Mar 2023 22:27:43 +0000 Subject: [PATCH] Some small refactoring --- .gitlab-ci.yml | 5 +- .pre-commit-config.yaml | 15 +- MANIFEST.in | 3 +- nerval/__init__.py | 8 + nerval/cli.py | 87 +++++++ nerval/evaluate.py | 346 +------------------------- nerval/parse.py | 203 +++++++++++++++ nerval/utils.py | 44 ++++ setup.py | 38 ++- tests/__init__.py | 0 tests/conftest.py | 51 ++++ tests/{ => fixtures}/bioeslu.bio | 0 tests/{ => fixtures}/end_of_file.bio | 0 tests/{ => fixtures}/test_annot.bio | 0 tests/{ => fixtures}/test_bad.bio | 0 tests/{ => fixtures}/test_empty.bio | 0 tests/{ => fixtures}/test_nested.bio | 0 tests/{ => fixtures}/test_predict.bio | 0 tests/test_align.py | 30 +-- tests/test_compute_matches.py | 46 ++-- tests/test_compute_scores.py | 85 ++++--- tests/test_parse_bio.py | 27 +- tests/test_run.py | 167 +++++++------ tox.ini | 4 +- 24 files changed, 626 insertions(+), 533 deletions(-) create mode 100644 nerval/cli.py create mode 100644 nerval/parse.py create mode 100644 nerval/utils.py delete mode 100644 tests/__init__.py create mode 100644 tests/conftest.py rename tests/{ => fixtures}/bioeslu.bio (100%) rename tests/{ => fixtures}/end_of_file.bio (100%) rename tests/{ => fixtures}/test_annot.bio (100%) rename tests/{ => fixtures}/test_bad.bio (100%) rename tests/{ => fixtures}/test_empty.bio (100%) rename tests/{ => fixtures}/test_nested.bio (100%) rename tests/{ => fixtures}/test_predict.bio (100%) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 93558bf..63f5e7e 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -1,6 +1,5 @@ stages: - test - - build - release variables: @@ -11,7 +10,7 @@ cache: linter: stage: test - image: python:3.8 + image: python:3 cache: paths: @@ -32,7 +31,7 @@ linter: tests: stage: test - image: python:3.8 + image: python:3 cache: paths: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ca8367e..9a3c7b8 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,22 +1,21 @@ repos: - - repo: https://github.com/pre-commit/mirrors-isort - rev: v5.10.1 + - repo: https://github.com/PyCQA/isort + rev: 5.12.0 hooks: - id: isort - repo: https://github.com/ambv/black - rev: 22.10.0 + rev: 23.1.0 hooks: - id: black - repo: https://github.com/pycqa/flake8 - rev: 3.9.2 + rev: 6.0.0 hooks: - id: flake8 additional_dependencies: - - 'flake8-coding==1.3.1' - - 'flake8-copyright==0.2.2' - - 'flake8-debugger==3.1.0' + - 'flake8-coding==1.3.2' + - 'flake8-debugger==4.1.2' - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.3.0 + rev: v4.4.0 hooks: - id: check-ast - id: check-docstring-first diff --git a/MANIFEST.in b/MANIFEST.in index bb3ec5f..fd959fa 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1 +1,2 @@ -include README.md +include requirements.txt +include VERSION diff --git a/nerval/__init__.py b/nerval/__init__.py index e69de29..b74e888 100644 --- a/nerval/__init__.py +++ b/nerval/__init__.py @@ -0,0 +1,8 @@ +# -*- coding: utf-8 -*- +import logging + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s %(levelname)s/%(name)s: %(message)s", +) +logger = logging.getLogger(__name__) diff --git a/nerval/cli.py b/nerval/cli.py new file mode 100644 index 0000000..8604b7f --- /dev/null +++ b/nerval/cli.py @@ -0,0 +1,87 @@ +# -*- coding: utf-8 -*- +import argparse +from pathlib import Path + +from nerval.evaluate import run, run_multiple + +THRESHOLD = 0.30 + + +def 
threshold_float_type(arg): + """Type function for argparse.""" + try: + f = float(arg) + except ValueError: + raise argparse.ArgumentTypeError("Must be a floating point number.") + if f < 0 or f > 1: + raise argparse.ArgumentTypeError("Must be between 0 and 1.") + return f + + +def parse_args(): + """Get arguments and run.""" + parser = argparse.ArgumentParser(description="Compute score of NER on predict.") + + group = parser.add_mutually_exclusive_group() + group.add_argument( + "-a", + "--annot", + help="Annotation in BIO format.", + ) + group.add_argument( + "-c", + "--csv", + help="CSV with the correlation between the annotation bio files and the predict bio files", + type=Path, + ) + parser.add_argument( + "-p", + "--predict", + help="Prediction in BIO format.", + ) + parser.add_argument( + "-f", + "--folder", + help="Folder containing the bio files referred to in the csv file", + type=Path, + ) + parser.add_argument( + "-v", + "--verbose", + help="Print only the recap if False", + action="store_false", + ) + parser.add_argument( + "-t", + "--threshold", + help="Set a distance threshold for the match between gold and predicted entity.", + default=THRESHOLD, + type=threshold_float_type, + ) + + return parser.parse_args() + + +def main(): + args = parse_args() + + if args.annot: + if not args.predict: + raise argparse.ArgumentTypeError( + "You need to specify the path to a predict file with -p" + ) + run(args.annot, args.predict, args.threshold, args.verbose) + elif args.csv: + if not args.folder: + raise argparse.ArgumentTypeError( + "You need to specify the path to a folder of bio files with -f" + ) + run_multiple(args.csv, args.folder, args.threshold, args.verbose) + else: + raise argparse.ArgumentTypeError( + "You need to specify the argument of input file" + ) + + +if __name__ == "__main__": + main() diff --git a/nerval/evaluate.py b/nerval/evaluate.py index 80335e9..5cbe3ab 100644 --- a/nerval/evaluate.py +++ b/nerval/evaluate.py @@ -1,10 +1,8 @@ # -*- coding: utf-8 -*- -import argparse import glob import logging import os -import re from csv import reader from pathlib import Path @@ -12,211 +10,17 @@ import editdistance import edlib import termtables as tt -NOT_ENTITY_TAG = "O" +from nerval.parse import ( + BEGINNING_POS, + NOT_ENTITY_TAG, + get_position_label, + get_type_label, + look_for_further_entity_part, + parse_bio, +) +from nerval.utils import print_result_compact, print_results -THRESHOLD = 0.30 -BEGINNING_POS = ["B", "S", "U"] - - -def get_type_label(label: str) -> str: - """Return the type (tag) of a label - - Input format: "[BIESLU]-type" - """ - try: - tag = ( - NOT_ENTITY_TAG - if label == NOT_ENTITY_TAG - else re.match(r"[BIESLU]-(.*)$", label)[1] - ) - except TypeError: - raise (Exception(f"The label {label} is not valid in BIOES/BIOLU format.")) - - return tag - - -def get_position_label(label: str) -> str: - """Return the position of a label - - Input format: "[BIESLU]-type" - """ - try: - pos = ( - NOT_ENTITY_TAG - if label == NOT_ENTITY_TAG - else re.match(r"([BIESLU])-(.*)$", label)[1] - ) - except TypeError: - raise (Exception(f"The label {label} is not valid in BIOES/BIOLU format.")) - - return pos - - -def parse_bio(path: str) -> dict: - """Parse a BIO file to get text content, character-level NE labels and entity types count. 
- - Input : path to a valid BIO file - Output format : { "words": str, "labels": list; "entity_count" : { tag : int } } - """ - assert os.path.exists(path), f"Error: Input file {path} does not exist" - - words = [] - labels = [] - entity_count = {"All": 0} - last_tag = None - - with open(path, "r") as fd: - lines = list(filter(lambda x: x != "\n", fd.readlines())) - - if "§" in " ".join(lines): - raise ( - Exception( - f"§ found in input file {path}. Since this character is used in a specific way during evaluation, prease remove it from files." - ) - ) - - # Track nested entities infos - in_nested_entity = False - containing_tag = None - - for index, line in enumerate(lines): - - try: - word, label = line.split() - except ValueError: - raise ( - Exception( - f"The file {path} given in input is not in BIO format: check line {index} ({line})" - ) - ) - - # Preserve hyphens to avoid confusion with the hyphens added later during alignment - word = word.replace("-", "§") - words.append(word) - - tag = get_type_label(label) - - # Spaces will be added between words and have to get a label - if index != 0: - - # If new word has same tag as previous, not new entity and in entity, continue entity - if ( - last_tag == tag - and get_position_label(label) not in BEGINNING_POS - and tag != NOT_ENTITY_TAG - ): - labels.append(f"I-{last_tag}") - - # If new word begins a new entity of different type, check for nested entity to correctly tag the space - elif ( - last_tag != tag - and get_position_label(label) in BEGINNING_POS - and tag != NOT_ENTITY_TAG - and last_tag != NOT_ENTITY_TAG - ): - - # Advance to next word with different label as current - future_label = label - while ( - index < len(lines) - and future_label != NOT_ENTITY_TAG - and get_type_label(future_label) != last_tag - ): - index += 1 - if index < len(lines): - future_label = lines[index].split()[1] - - # Check for continuation of the original entity - if ( - index < len(lines) - and get_position_label(future_label) not in BEGINNING_POS - and get_type_label(future_label) == last_tag - ): - labels.append(f"I-{last_tag}") - in_nested_entity = True - containing_tag = last_tag - else: - labels.append(NOT_ENTITY_TAG) - in_nested_entity = False - - elif in_nested_entity: - labels.append(f"I-{containing_tag}") - - else: - labels.append(NOT_ENTITY_TAG) - in_nested_entity = False - - # Add a tag for each letter in the word - if get_position_label(label) in BEGINNING_POS: - labels += [f"B-{tag}"] + [f"I-{tag}"] * (len(word) - 1) - else: - labels += [label] * len(word) - - # Count nb entity for each type - if get_position_label(label) in BEGINNING_POS: - entity_count[tag] = entity_count.get(tag, 0) + 1 - entity_count["All"] += 1 - - last_tag = tag - - result = None - - if words: - - result = dict() - result["words"] = " ".join(words) - result["labels"] = labels - result["entity_count"] = entity_count - - assert len(result["words"]) == len(result["labels"]) - for tag in result["entity_count"]: - if tag != "All": - assert result["labels"].count(f"B-{tag}") == result["entity_count"][tag] - - return result - - -def look_for_further_entity_part(index, tag, characters, labels): - """Get further entities parts for long entities with nested entities. 
- - Input: - index: the starting index to look for rest of entity (one after last character included) - tag: the type of the entity investigated - characters: the string of the annotation or prediction - the labels associated with characters - Output : - complete string of the rest of the entity found - visited: indexes of the characters used for this last entity part OF THE DESIGNATED TAG. Do not process again later - """ - original_index = index - last_loop_index = index - research = True - visited = [] - while research: - while ( - index < len(characters) - and labels[index] != NOT_ENTITY_TAG - and get_type_label(labels[index]) != tag - ): - index += 1 - while ( - index < len(characters) - and get_position_label(labels[index]) not in BEGINNING_POS - and get_type_label(labels[index]) == tag - ): - visited.append(index) - index += 1 - - research = index != last_loop_index and get_type_label(labels[index - 1]) == tag - last_loop_index = index - - characters_to_add = ( - characters[original_index:index] - if get_type_label(labels[index - 1]) == tag - else [] - ) - - return characters_to_add, visited +logger = logging.getLogger(__name__) def compute_matches( @@ -268,7 +72,6 @@ def compute_matches( # Iterating on reference string for i, char_annot in enumerate(annotation): - if i in visited_annot: continue @@ -282,7 +85,6 @@ def compute_matches( last_tag = NOT_ENTITY_TAG else: - # If beginning new entity if get_position_label(label_ref) in BEGINNING_POS: current_ref, current_compar = [], [] @@ -294,7 +96,6 @@ def compute_matches( # Searching character string corresponding with tag if not found_aligned_end and tag_predict == tag_ref: - if i in visited_predict: continue @@ -335,7 +136,6 @@ def compute_matches( i + 1 < len(annotation) and get_type_label(labels_annot[i + 1]) != last_tag ): - if not found_aligned_end: rest_predict, visited = look_for_further_entity_part( i + 1, tag_ref, prediction, labels_predict @@ -466,50 +266,7 @@ def compute_scores( return scores -def print_results(scores: dict): - """Display final results. - - None values are kept to indicate the absence of a certain tag in either annotation or prediction. - """ - header = ["tag", "predicted", "matched", "Precision", "Recall", "F1", "Support"] - results = [] - for tag in sorted(scores.keys())[::-1]: - prec = None if scores[tag]["P"] is None else round(scores[tag]["P"], 3) - rec = None if scores[tag]["R"] is None else round(scores[tag]["R"], 3) - f1 = None if scores[tag]["F1"] is None else round(scores[tag]["F1"], 3) - - results.append( - [ - tag, - scores[tag]["predicted"], - scores[tag]["matched"], - prec, - rec, - f1, - scores[tag]["Support"], - ] - ) - tt.print(results, header, style=tt.styles.markdown) - - -def print_result_compact(scores: dict): - result = [] - header = ["tag", "predicted", "matched", "Precision", "Recall", "F1", "Support"] - result.append( - [ - "ALl", - scores["All"]["predicted"], - scores["All"]["matched"], - round(scores["All"]["P"], 3), - round(scores["All"]["R"], 3), - round(scores["All"]["F1"], 3), - scores["All"]["Support"], - ] - ) - tt.print(result, header, style=tt.styles.markdown) - - -def run(annotation: str, prediction: str, threshold: int, verbose: bool) -> dict: +def run(annotation: Path, prediction: Path, threshold: int, verbose: bool) -> dict: """Compute recall and precision for each entity type found in annotation and/or prediction. Each measure is given at document level, global score is a micro-average across entity types. 
@@ -586,16 +343,14 @@ def run_multiple(file_csv, folder, threshold, verbose): if annot and predict: count += 1 - print(os.path.basename(predict)) scores = run(annot, predict, threshold, verbose) precision += scores["All"]["P"] recall += scores["All"]["R"] f1 += scores["All"]["F1"] - print() else: raise Exception(f"No file found for files {annot}, {predict}") if count: - print("Average score on all corpus") + logger.info("Average score on all corpus") tt.print( [ [ @@ -611,80 +366,3 @@ def run_multiple(file_csv, folder, threshold, verbose): raise Exception("No file were counted") else: raise Exception("the path indicated does not lead to a folder.") - - -def threshold_float_type(arg): - """Type function for argparse.""" - try: - f = float(arg) - except ValueError: - raise argparse.ArgumentTypeError("Must be a floating point number.") - if f < 0 or f > 1: - raise argparse.ArgumentTypeError("Must be between 0 and 1.") - return f - - -def main(): - """Get arguments and run.""" - - logging.basicConfig(level=logging.INFO) - - parser = argparse.ArgumentParser(description="Compute score of NER on predict.") - - group = parser.add_mutually_exclusive_group() - group.add_argument( - "-a", - "--annot", - help="Annotation in BIO format.", - ) - group.add_argument( - "-c", - "--csv", - help="Csv with the correlation between the annotation bio files and the predict bio files", - type=Path, - ) - parser.add_argument( - "-p", - "--predict", - help="Prediction in BIO format.", - ) - parser.add_argument( - "-f", - "--folder", - help="Folder containing the bio files referred to in the csv file", - type=Path, - ) - parser.add_argument( - "-v", - "--verbose", - help="Print only the recap if False", - action="store_false", - ) - parser.add_argument( - "-t", - "--threshold", - help="Set a distance threshold for the match between gold and predicted entity.", - default=THRESHOLD, - type=threshold_float_type, - ) - - args = parser.parse_args() - - if args.annot: - if not args.predict: - raise parser.error("You need to specify the path to a predict file with -p") - if args.annot and args.predict: - run(args.annot, args.predict, args.threshold, args.verbose) - elif args.csv: - if not args.folder: - raise parser.error( - "You need to specify the path to a folder of bio files with -f" - ) - if args.folder and args.csv: - run_multiple(args.csv, args.folder, args.threshold, args.verbose) - else: - raise parser.error("You need to specify the argument of input file") - - -if __name__ == "__main__": - main() diff --git a/nerval/parse.py b/nerval/parse.py new file mode 100644 index 0000000..60c50fd --- /dev/null +++ b/nerval/parse.py @@ -0,0 +1,203 @@ +# -*- coding: utf-8 -*- +import re +from pathlib import Path + +NOT_ENTITY_TAG = "O" +BEGINNING_POS = ["B", "S", "U"] + + +def get_type_label(label: str) -> str: + """Return the type (tag) of a label + + Input format: "[BIESLU]-type" + """ + try: + tag = ( + NOT_ENTITY_TAG + if label == NOT_ENTITY_TAG + else re.match(r"[BIESLU]-(.*)$", label)[1] + ) + except TypeError: + raise (Exception(f"The label {label} is not valid in BIOES/BIOLU format.")) + + return tag + + +def get_position_label(label: str) -> str: + """Return the position of a label + + Input format: "[BIESLU]-type" + """ + try: + pos = ( + NOT_ENTITY_TAG + if label == NOT_ENTITY_TAG + else re.match(r"([BIESLU])-(.*)$", label)[1] + ) + except TypeError: + raise (Exception(f"The label {label} is not valid in BIOES/BIOLU format.")) + + return pos + + +def parse_bio(path: Path) -> dict: + """Parse a BIO file to get text 
content, character-level NE labels and entity types count. + + Input : path to a valid BIO file + Output format : { "words": str, "labels": list; "entity_count" : { tag : int } } + """ + assert path.exists(), f"Error: Input file {path} does not exist" + + words = [] + labels = [] + entity_count = {"All": 0} + last_tag = None + + with open(path, "r") as fd: + lines = list(filter(lambda x: x != "\n", fd.readlines())) + + if "§" in " ".join(lines): + raise ( + Exception( + f"§ found in input file {path}. Since this character is used in a specific way during evaluation, prease remove it from files." + ) + ) + + # Track nested entities infos + in_nested_entity = False + containing_tag = None + + for index, line in enumerate(lines): + try: + word, label = line.split() + except ValueError: + raise ( + Exception( + f"The file {path} given in input is not in BIO format: check line {index} ({line})" + ) + ) + + # Preserve hyphens to avoid confusion with the hyphens added later during alignment + word = word.replace("-", "§") + words.append(word) + + tag = get_type_label(label) + + # Spaces will be added between words and have to get a label + if index != 0: + # If new word has same tag as previous, not new entity and in entity, continue entity + if ( + last_tag == tag + and get_position_label(label) not in BEGINNING_POS + and tag != NOT_ENTITY_TAG + ): + labels.append(f"I-{last_tag}") + + # If new word begins a new entity of different type, check for nested entity to correctly tag the space + elif ( + last_tag != tag + and get_position_label(label) in BEGINNING_POS + and tag != NOT_ENTITY_TAG + and last_tag != NOT_ENTITY_TAG + ): + # Advance to next word with different label as current + future_label = label + while ( + index < len(lines) + and future_label != NOT_ENTITY_TAG + and get_type_label(future_label) != last_tag + ): + index += 1 + if index < len(lines): + future_label = lines[index].split()[1] + + # Check for continuation of the original entity + if ( + index < len(lines) + and get_position_label(future_label) not in BEGINNING_POS + and get_type_label(future_label) == last_tag + ): + labels.append(f"I-{last_tag}") + in_nested_entity = True + containing_tag = last_tag + else: + labels.append(NOT_ENTITY_TAG) + in_nested_entity = False + + elif in_nested_entity: + labels.append(f"I-{containing_tag}") + + else: + labels.append(NOT_ENTITY_TAG) + in_nested_entity = False + + # Add a tag for each letter in the word + if get_position_label(label) in BEGINNING_POS: + labels += [f"B-{tag}"] + [f"I-{tag}"] * (len(word) - 1) + else: + labels += [label] * len(word) + + # Count nb entity for each type + if get_position_label(label) in BEGINNING_POS: + entity_count[tag] = entity_count.get(tag, 0) + 1 + entity_count["All"] += 1 + + last_tag = tag + + result = None + + if words: + result = dict() + result["words"] = " ".join(words) + result["labels"] = labels + result["entity_count"] = entity_count + + assert len(result["words"]) == len(result["labels"]) + for tag in result["entity_count"]: + if tag != "All": + assert result["labels"].count(f"B-{tag}") == result["entity_count"][tag] + + return result + + +def look_for_further_entity_part(index, tag, characters, labels): + """Get further entities parts for long entities with nested entities. 
+ + Input: + index: the starting index to look for rest of entity (one after last character included) + tag: the type of the entity investigated + characters: the string of the annotation or prediction + the labels associated with characters + Output : + complete string of the rest of the entity found + visited: indexes of the characters used for this last entity part OF THE DESIGNATED TAG. Do not process again later + """ + original_index = index + last_loop_index = index + research = True + visited = [] + while research: + while ( + index < len(characters) + and labels[index] != NOT_ENTITY_TAG + and get_type_label(labels[index]) != tag + ): + index += 1 + while ( + index < len(characters) + and get_position_label(labels[index]) not in BEGINNING_POS + and get_type_label(labels[index]) == tag + ): + visited.append(index) + index += 1 + + research = index != last_loop_index and get_type_label(labels[index - 1]) == tag + last_loop_index = index + + characters_to_add = ( + characters[original_index:index] + if get_type_label(labels[index - 1]) == tag + else [] + ) + + return characters_to_add, visited diff --git a/nerval/utils.py b/nerval/utils.py new file mode 100644 index 0000000..595299e --- /dev/null +++ b/nerval/utils.py @@ -0,0 +1,44 @@ +# -*- coding: utf-8 -*- +import termtables as tt + + +def print_results(scores: dict): + """Display final results. + + None values are kept to indicate the absence of a certain tag in either annotation or prediction. + """ + header = ["tag", "predicted", "matched", "Precision", "Recall", "F1", "Support"] + results = [] + for tag in sorted(scores, reverse=True): + prec = None if scores[tag]["P"] is None else round(scores[tag]["P"], 3) + rec = None if scores[tag]["R"] is None else round(scores[tag]["R"], 3) + f1 = None if scores[tag]["F1"] is None else round(scores[tag]["F1"], 3) + + results.append( + [ + tag, + scores[tag]["predicted"], + scores[tag]["matched"], + prec, + rec, + f1, + scores[tag]["Support"], + ] + ) + tt.print(results, header, style=tt.styles.markdown) + + +def print_result_compact(scores: dict): + header = ["tag", "predicted", "matched", "Precision", "Recall", "F1", "Support"] + result = [ + [ + "All", + scores["All"]["predicted"], + scores["All"]["matched"], + round(scores["All"]["P"], 3), + round(scores["All"]["R"], 3), + round(scores["All"]["F1"], 3), + scores["All"]["Support"], + ] + ] + tt.print(result, header, style=tt.styles.markdown) diff --git a/setup.py b/setup.py index d98c59d..427ffda 100644 --- a/setup.py +++ b/setup.py @@ -1,24 +1,38 @@ +#!/usr/bin/env python # -*- coding: utf-8 -*- -import os.path +from pathlib import Path -from setuptools import setup +from setuptools import find_packages, setup -def requirements(path): - assert os.path.exists(path), "Missing requirements {}.format(path)" - with open(path) as f: - return list(map(str.strip, f.read().splitlines())) +def parse_requirements_line(line): + """Special case for git requirements""" + if line.startswith("git+http"): + assert "@" in line, "Branch should be specified with suffix (ex: @master)" + assert ( + "#egg=" in line + ), "Package name should be specified with suffix (ex: #egg=kraken)" + package_name = line.split("#egg=")[-1] + return f"{package_name} @ {line}" + else: + return line -install_requires = requirements("requirements.txt") +def parse_requirements(): + path = Path(__file__).parent.resolve() / "requirements.txt" + assert path.exists(), f"Missing requirements: {path}" + return list( + map(parse_requirements_line, map(str.strip, 
path.read_text().splitlines())) + ) + setup( - name="Nerval", + name="nerval", version=open("VERSION").read(), description="Tool to evaluate NER on noisy text.", author="Teklia", - author_email="bmiret@teklia.com", - packages=["nerval"], - entry_points={"console_scripts": ["nerval=nerval.evaluate:main"]}, - install_requires=install_requires, + author_email="contact@teklia.com", + packages=find_packages(), + entry_points={"console_scripts": ["nerval=nerval.cli:main"]}, + install_requires=parse_requirements(), ) diff --git a/tests/__init__.py b/tests/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..d73df4f --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,51 @@ +# -*- coding: utf-8 -*- +from pathlib import Path + +import pytest + +FIXTURES = Path(__file__).parent / "fixtures" + + +@pytest.fixture +def fake_annot_bio(): + return FIXTURES / "test_annot.bio" + + +@pytest.fixture +def fake_predict_bio(): + return FIXTURES / "test_predict.bio" + + +@pytest.fixture +def empty_bio(): + return FIXTURES / "test_empty.bio" + + +@pytest.fixture +def bad_bio(): + return FIXTURES / "test_bad.bio" + + +@pytest.fixture +def bioeslu_bio(): + return FIXTURES / "bioeslu.bio" + + +@pytest.fixture +def end_of_file_bio(): + return FIXTURES / "end_of_file.bio" + + +@pytest.fixture +def nested_bio(): + return FIXTURES / "test_nested.bio" + + +@pytest.fixture +def folder_bio(): + return Path("test_folder") + + +@pytest.fixture() +def csv_file(): + return Path("test_mapping_file.csv") diff --git a/tests/bioeslu.bio b/tests/fixtures/bioeslu.bio similarity index 100% rename from tests/bioeslu.bio rename to tests/fixtures/bioeslu.bio diff --git a/tests/end_of_file.bio b/tests/fixtures/end_of_file.bio similarity index 100% rename from tests/end_of_file.bio rename to tests/fixtures/end_of_file.bio diff --git a/tests/test_annot.bio b/tests/fixtures/test_annot.bio similarity index 100% rename from tests/test_annot.bio rename to tests/fixtures/test_annot.bio diff --git a/tests/test_bad.bio b/tests/fixtures/test_bad.bio similarity index 100% rename from tests/test_bad.bio rename to tests/fixtures/test_bad.bio diff --git a/tests/test_empty.bio b/tests/fixtures/test_empty.bio similarity index 100% rename from tests/test_empty.bio rename to tests/fixtures/test_empty.bio diff --git a/tests/test_nested.bio b/tests/fixtures/test_nested.bio similarity index 100% rename from tests/test_nested.bio rename to tests/fixtures/test_nested.bio diff --git a/tests/test_predict.bio b/tests/fixtures/test_predict.bio similarity index 100% rename from tests/test_predict.bio rename to tests/fixtures/test_predict.bio diff --git a/tests/test_align.py b/tests/test_align.py index d432025..a18c199 100644 --- a/tests/test_align.py +++ b/tests/test_align.py @@ -2,21 +2,21 @@ import edlib import pytest -fake_annot_original = "Gérard de Nerval was born in Paris in 1808 ." -fake_predict_original = "G*rard de *N*erval bo*rn in Paris in 1833 *." 
- -expected_alignment = { - "query_aligned": "Gérard de -N-erval was bo-rn in Paris in 1808 -.", - "matched_aligned": "|.||||||||-|-||||||----||-|||||||||||||||||..|-|", - "target_aligned": "G*rard de *N*erval ----bo*rn in Paris in 1833 *.", -} - @pytest.mark.parametrize( - "test_input, expected", - [((fake_annot_original, fake_predict_original), expected_alignment)], + "query,target", + [ + ( + "Gérard de Nerval was born in Paris in 1808 .", + "G*rard de *N*erval bo*rn in Paris in 1833 *.", + ) + ], ) -def test_align(test_input, expected): - a = edlib.align(*test_input, task="path") - result_alignment = edlib.getNiceAlignment(a, *test_input) - assert result_alignment == expected +def test_align(query, target): + a = edlib.align(query, target, task="path") + result_alignment = edlib.getNiceAlignment(a, query, target) + assert result_alignment == { + "query_aligned": "Gérard de -N-erval was bo-rn in Paris in 1808 -.", + "matched_aligned": "|.||||||||-|-||||||----||-|||||||||||||||||..|-|", + "target_aligned": "G*rard de *N*erval ----bo*rn in Paris in 1833 *.", + } diff --git a/tests/test_compute_matches.py b/tests/test_compute_matches.py index f472e84..1f62ce0 100644 --- a/tests/test_compute_matches.py +++ b/tests/test_compute_matches.py @@ -5,10 +5,6 @@ from nerval import evaluate THRESHOLD = 0.30 -fake_annot_aligned = "Gérard de -N-erval was bo-rn in Paris in 1808 -." -fake_predict_aligned = "G*rard de *N*erval ----bo*rn in Paris in 1833 *." - -fake_string_nested = "Louis par la grâce de Dieu roy de France et de Navarre." # fmt: off fake_tags_aligned_nested_perfect = [ @@ -141,12 +137,6 @@ fake_annot_tags_aligned = [ "O", ] -expected_matches = {"All": 1, "PER": 1, "LOC": 0, "DAT": 0} -expected_matches_nested_perfect = {"All": 3, "PER": 1, "LOC": 2} -expected_matches_nested_false = {"All": 2, "PER": 1, "LOC": 1} - -fake_annot_backtrack_boundary = "The red dragon" - fake_annot_tags_bk_boundary = [ "O", "O", @@ -181,10 +171,6 @@ fake_predict_tags_bk_boundary = [ "I-PER", ] -expected_matches_bk_boundary = {"All": 0, "PER": 0} - -fake_annot_backtrack_boundary_2 = "A red dragon" - fake_annot_tags_bk_boundary_2 = [ "O", "O", @@ -215,61 +201,59 @@ fake_predict_tags_bk_boundary_2 = [ "I-PER", ] -expected_matches_bk_boundary_2 = {"All": 1, "PER": 1} - @pytest.mark.parametrize( "test_input, expected", [ ( ( - fake_annot_aligned, - fake_predict_aligned, + "Gérard de -N-erval was bo-rn in Paris in 1808 -.", + "G*rard de *N*erval ----bo*rn in Paris in 1833 *.", fake_annot_tags_aligned, fake_predict_tags_aligned, THRESHOLD, ), - expected_matches, + {"All": 1, "PER": 1, "LOC": 0, "DAT": 0}, ), ( ( - fake_string_nested, - fake_string_nested, + "Louis par la grâce de Dieu roy de France et de Navarre.", + "Louis par la grâce de Dieu roy de France et de Navarre.", fake_tags_aligned_nested_perfect, fake_tags_aligned_nested_perfect, THRESHOLD, ), - expected_matches_nested_perfect, + {"All": 3, "PER": 1, "LOC": 2}, ), ( ( - fake_string_nested, - fake_string_nested, + "Louis par la grâce de Dieu roy de France et de Navarre.", + "Louis par la grâce de Dieu roy de France et de Navarre.", fake_tags_aligned_nested_perfect, fake_tags_aligned_nested_false, THRESHOLD, ), - expected_matches_nested_false, + {"All": 2, "PER": 1, "LOC": 1}, ), ( ( - fake_annot_backtrack_boundary, - fake_annot_backtrack_boundary, + "The red dragon", + "The red dragon", fake_annot_tags_bk_boundary, fake_predict_tags_bk_boundary, THRESHOLD, ), - expected_matches_bk_boundary, + {"All": 0, "PER": 0}, ), ( ( - fake_annot_backtrack_boundary_2, - 
fake_annot_backtrack_boundary_2, + "A red dragon", + "A red dragon", fake_annot_tags_bk_boundary_2, fake_predict_tags_bk_boundary_2, THRESHOLD, ), - expected_matches_bk_boundary_2, + {"All": 1, "PER": 1}, ), ], ) diff --git a/tests/test_compute_scores.py b/tests/test_compute_scores.py index 02e26cf..e4dc730 100644 --- a/tests/test_compute_scores.py +++ b/tests/test_compute_scores.py @@ -3,48 +3,57 @@ import pytest from nerval import evaluate -fake_annot_entity_count = {"All": 3, "DAT": 1, "LOC": 1, "PER": 1} -fake_predict_entity_count = {"All": 3, "DAT": 1, "***": 1, "PER": 1} -fake_matches = {"All": 1, "PER": 1, "LOC": 0, "DAT": 0} - -expected_scores = { - "***": { - "P": 0.0, - "R": None, - "F1": None, - "predicted": 1, - "matched": 0, - "Support": None, - }, - "DAT": {"P": 0.0, "R": 0.0, "F1": 0, "predicted": 1, "matched": 0, "Support": 1}, - "All": { - "P": 0.3333333333333333, - "R": 0.3333333333333333, - "F1": 0.3333333333333333, - "predicted": 3, - "matched": 1, - "Support": 3, - }, - "PER": {"P": 1.0, "R": 1.0, "F1": 1.0, "predicted": 1, "matched": 1, "Support": 1}, - "LOC": { - "P": None, - "R": 0.0, - "F1": None, - "predicted": None, - "matched": 0, - "Support": 1, - }, -} - @pytest.mark.parametrize( - "test_input, expected", + "annot,predict,matches", [ ( - (fake_annot_entity_count, fake_predict_entity_count, fake_matches), - expected_scores, + {"All": 3, "DAT": 1, "LOC": 1, "PER": 1}, + {"All": 3, "DAT": 1, "***": 1, "PER": 1}, + {"All": 1, "PER": 1, "LOC": 0, "DAT": 0}, ) ], ) -def test_compute_scores(test_input, expected): - assert evaluate.compute_scores(*test_input) == expected +def test_compute_scores(annot, predict, matches): + assert evaluate.compute_scores(annot, predict, matches) == { + "***": { + "P": 0.0, + "R": None, + "F1": None, + "predicted": 1, + "matched": 0, + "Support": None, + }, + "DAT": { + "P": 0.0, + "R": 0.0, + "F1": 0, + "predicted": 1, + "matched": 0, + "Support": 1, + }, + "All": { + "P": 0.3333333333333333, + "R": 0.3333333333333333, + "F1": 0.3333333333333333, + "predicted": 3, + "matched": 1, + "Support": 3, + }, + "PER": { + "P": 1.0, + "R": 1.0, + "F1": 1.0, + "predicted": 1, + "matched": 1, + "Support": 1, + }, + "LOC": { + "P": None, + "R": 0.0, + "F1": None, + "predicted": None, + "matched": 0, + "Support": 1, + }, + } diff --git a/tests/test_parse_bio.py b/tests/test_parse_bio.py index 23df00b..db161b8 100644 --- a/tests/test_parse_bio.py +++ b/tests/test_parse_bio.py @@ -1,17 +1,10 @@ # -*- coding: utf-8 -*- +from pathlib import Path + import pytest from nerval import evaluate -NO_EXIST_BIO = "no_exist.bio" -EMPTY_BIO = "tests/test_empty.bio" -BAD_BIO = "tests/test_bad.bio" -FAKE_ANNOT_BIO = "tests/test_annot.bio" -FAKE_PREDICT_BIO = "tests/test_predict.bio" -BIOESLU_BIO = "tests/bioeslu.bio" -END_OF_FILE_BIO = "tests/end_of_file.bio" - - expected_parsed_annot = { "entity_count": {"All": 3, "DAT": 1, "LOC": 1, "PER": 1}, "labels": [ @@ -179,22 +172,22 @@ expected_parsed_end_of_file = { @pytest.mark.parametrize( "test_input, expected", [ - (FAKE_ANNOT_BIO, expected_parsed_annot), - (FAKE_PREDICT_BIO, expected_parsed_predict), - (EMPTY_BIO, None), - (BIOESLU_BIO, expected_parsed_annot), - (END_OF_FILE_BIO, expected_parsed_end_of_file), + (pytest.lazy_fixture("fake_annot_bio"), expected_parsed_annot), + (pytest.lazy_fixture("fake_predict_bio"), expected_parsed_predict), + (pytest.lazy_fixture("empty_bio"), None), + (pytest.lazy_fixture("bioeslu_bio"), expected_parsed_annot), + (pytest.lazy_fixture("end_of_file_bio"), 
expected_parsed_end_of_file), ], ) def test_parse_bio(test_input, expected): assert evaluate.parse_bio(test_input) == expected -def test_parse_bio_bad_input(): +def test_parse_bio_bad_input(bad_bio): with pytest.raises(Exception): - evaluate.parse_bio(BAD_BIO) + evaluate.parse_bio(bad_bio) def test_parse_bio_no_input(): with pytest.raises(AssertionError): - evaluate.parse_bio(NO_EXIST_BIO) + evaluate.parse_bio(Path("not_a_bio")) diff --git a/tests/test_run.py b/tests/test_run.py index 4a6e9d5..b88595a 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -3,88 +3,111 @@ import pytest from nerval import evaluate -THRESHOLD = 0.30 - -FAKE_ANNOT_BIO = "tests/test_annot.bio" -FAKE_PREDICT_BIO = "tests/test_predict.bio" -EMPTY_BIO = "tests/test_empty.bio" -FAKE_BIO_NESTED = "tests/test_nested.bio" -BIO_FOLDER = "test_folder" -CSV_FILE = "test_mapping_file.csv" - -expected_scores_nested = { - "All": { - "P": 1.0, - "R": 1.0, - "F1": 1.0, - "predicted": 3, - "matched": 3, - "Support": 3, - }, - "PER": {"P": 1.0, "R": 1.0, "F1": 1.0, "predicted": 1, "matched": 1, "Support": 1}, - "LOC": { - "P": 1.0, - "R": 1.0, - "F1": 1.0, - "predicted": 2, - "matched": 2, - "Support": 2, - }, -} - - -expected_scores = { - "***": { - "P": 0.0, - "R": None, - "F1": None, - "predicted": 1, - "matched": 0, - "Support": None, - }, - "DAT": {"P": 0.0, "R": 0.0, "F1": 0, "predicted": 1, "matched": 0, "Support": 1}, - "All": { - "P": 0.3333333333333333, - "R": 0.3333333333333333, - "F1": 0.3333333333333333, - "predicted": 3, - "matched": 1, - "Support": 3, - }, - "PER": {"P": 1.0, "R": 1.0, "F1": 1.0, "predicted": 1, "matched": 1, "Support": 1}, - "LOC": { - "P": None, - "R": 0.0, - "F1": None, - "predicted": None, - "matched": 0, - "Support": 1, - }, -} - @pytest.mark.parametrize( - "test_input, expected", - [ - ((FAKE_ANNOT_BIO, FAKE_PREDICT_BIO, THRESHOLD, True), expected_scores), - ((FAKE_BIO_NESTED, FAKE_BIO_NESTED, THRESHOLD, True), expected_scores_nested), - ], + "annotation, prediction, expected", + ( + ( + pytest.lazy_fixture("fake_annot_bio"), + pytest.lazy_fixture("fake_predict_bio"), + { + "***": { + "P": 0.0, + "R": None, + "F1": None, + "predicted": 1, + "matched": 0, + "Support": None, + }, + "DAT": { + "P": 0.0, + "R": 0.0, + "F1": 0, + "predicted": 1, + "matched": 0, + "Support": 1, + }, + "All": { + "P": 0.3333333333333333, + "R": 0.3333333333333333, + "F1": 0.3333333333333333, + "predicted": 3, + "matched": 1, + "Support": 3, + }, + "PER": { + "P": 1.0, + "R": 1.0, + "F1": 1.0, + "predicted": 1, + "matched": 1, + "Support": 1, + }, + "LOC": { + "P": None, + "R": 0.0, + "F1": None, + "predicted": None, + "matched": 0, + "Support": 1, + }, + }, + ), + ( + pytest.lazy_fixture("nested_bio"), + pytest.lazy_fixture("nested_bio"), + { + "All": { + "P": 1.0, + "R": 1.0, + "F1": 1.0, + "predicted": 3, + "matched": 3, + "Support": 3, + }, + "PER": { + "P": 1.0, + "R": 1.0, + "F1": 1.0, + "predicted": 1, + "matched": 1, + "Support": 1, + }, + "LOC": { + "P": 1.0, + "R": 1.0, + "F1": 1.0, + "predicted": 2, + "matched": 2, + "Support": 2, + }, + }, + ), + ), ) -def test_run(test_input, expected): - # print(evaluate.run(*test_input)) - assert evaluate.run(*test_input) == expected +def test_run(annotation, prediction, expected): + assert ( + evaluate.run( + annotation=annotation, + prediction=prediction, + threshold=0.3, + verbose=False, + ) + == expected + ) -def test_run_empty_bio(): +def test_run_empty_bio(empty_bio): with pytest.raises(Exception): - evaluate.run(EMPTY_BIO, EMPTY_BIO, THRESHOLD) + 
evaluate.run(empty_bio, empty_bio, 0.3) def test_run_empty_entry(): with pytest.raises(TypeError): - evaluate.run(None, None, THRESHOLD) + evaluate.run(None, None, 0.3) -def test_run_multiple(): +@pytest.mark.parametrize("threshold", ([0.3])) +def test_run_multiple(csv_file, folder_bio, threshold): with pytest.raises(Exception): - evaluate.run_multiple(CSV_FILE, BIO_FOLDER, THRESHOLD) + evaluate.run_multiple(csv_file, folder_bio, threshold) diff --git a/tox.ini b/tox.ini index 9d9d14c..a8d512b 100644 --- a/tox.ini +++ b/tox.ini @@ -1,11 +1,11 @@ [tox] envlist = nerval -skipsdist=True -[testenv:nerval] +[testenv] commands = pytest {posargs} deps = pytest + pytest-lazy-fixture -rrequirements.txt -- GitLab
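After this refactor the scoring logic stays in nerval.evaluate while BIO parsing moves to nerval.parse, table printing to nerval.utils, and argument handling to the new nerval.cli entry point. The following is a minimal usage sketch of driving the refactored library directly from Python; it is not part of the patch, and annot.bio / predict.bio are hypothetical placeholder paths used only for illustration.

    # Minimal sketch (assumes two existing BIO files at the placeholder paths below).
    from pathlib import Path

    from nerval.evaluate import run
    from nerval.parse import parse_bio

    annotation = Path("annot.bio")    # gold BIO file (hypothetical path)
    prediction = Path("predict.bio")  # predicted BIO file (hypothetical path)

    # parse_bio() returns {"words": str, "labels": list, "entity_count": {tag: int}},
    # or None when the file contains no words.
    gold = parse_bio(annotation)
    print(gold["entity_count"])

    # run() aligns both texts, matches entities within the distance threshold and
    # returns one entry per tag with P, R, F1, predicted, matched and Support.
    scores = run(annotation, prediction, threshold=0.30, verbose=True)
    print(scores["All"]["F1"])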
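From the command line, the same flow is exposed through the relocated console script (nerval.cli:main): "nerval -a annot.bio -p predict.bio" scores a single annotation/prediction pair, "nerval -c mapping.csv -f bio_folder" scores every pair listed in a CSV mapping file, and "-t" overrides the default 0.30 match threshold. The file and folder names here are again placeholders, not files shipped with the repository.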