Commit 2ac2b607 authored by Yoann Schneider

Some small refactoring

parent 25a70324
1 merge request: !20 Some small refactoring
Pipeline #104003 passed
Showing 472 additions and 404 deletions
 stages:
   - test
-  - build
   - release
 variables:
@@ -11,7 +10,7 @@ cache:
 linter:
   stage: test
-  image: python:3.8
+  image: python:3
   cache:
     paths:
@@ -32,7 +31,7 @@ linter:
 tests:
   stage: test
-  image: python:3.8
+  image: python:3
   cache:
     paths:
...
 repos:
-  - repo: https://github.com/pre-commit/mirrors-isort
-    rev: v5.10.1
+  - repo: https://github.com/PyCQA/isort
+    rev: 5.12.0
     hooks:
       - id: isort
   - repo: https://github.com/ambv/black
-    rev: 22.10.0
+    rev: 23.1.0
     hooks:
       - id: black
   - repo: https://github.com/pycqa/flake8
-    rev: 3.9.2
+    rev: 6.0.0
     hooks:
       - id: flake8
         additional_dependencies:
-          - 'flake8-coding==1.3.1'
-          - 'flake8-copyright==0.2.2'
-          - 'flake8-debugger==3.1.0'
+          - 'flake8-coding==1.3.2'
+          - 'flake8-debugger==4.1.2'
   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v4.3.0
+    rev: v4.4.0
     hooks:
       - id: check-ast
       - id: check-docstring-first
...
-include README.md
+include requirements.txt
+include VERSION
# -*- coding: utf-8 -*-
import logging

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s/%(name)s: %(message)s",
)

logger = logging.getLogger(__name__)
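Editor's aside, not part of the commit: a minimal standalone sketch of what this logging format renders, with an illustrative timestamp.

# Sketch only: shows the record layout produced by the format string above.
import logging

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s/%(name)s: %(message)s",
)
logging.getLogger("nerval.evaluate").info("Average score on all corpus")
# Prints something like:
# 2023-02-28 14:03:07,123 INFO/nerval.evaluate: Average score on all corpus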
# -*- coding: utf-8 -*-
import argparse
from pathlib import Path

from nerval.evaluate import run, run_multiple

THRESHOLD = 0.30


def threshold_float_type(arg):
    """Type function for argparse."""
    try:
        f = float(arg)
    except ValueError:
        raise argparse.ArgumentTypeError("Must be a floating point number.")
    if f < 0 or f > 1:
        raise argparse.ArgumentTypeError("Must be between 0 and 1.")
    return f
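As an editor's illustration (not part of the commit): argparse passes the raw string to this validator and reports ArgumentTypeError as a clean usage error.

# Illustrative behavior of threshold_float_type; values are arbitrary.
assert threshold_float_type("0.25") == 0.25
# threshold_float_type("1.5")  raises argparse.ArgumentTypeError("Must be between 0 and 1.")
# threshold_float_type("foo")  raises argparse.ArgumentTypeError("Must be a floating point number.")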
def parse_args():
    """Get arguments and run."""
    parser = argparse.ArgumentParser(description="Compute score of NER on predict.")
    group = parser.add_mutually_exclusive_group()
    group.add_argument(
        "-a",
        "--annot",
        help="Annotation in BIO format.",
    )
    group.add_argument(
        "-c",
        "--csv",
        help="CSV with the correlation between the annotation bio files and the predict bio files",
        type=Path,
    )
    parser.add_argument(
        "-p",
        "--predict",
        help="Prediction in BIO format.",
    )
    parser.add_argument(
        "-f",
        "--folder",
        help="Folder containing the bio files referred to in the csv file",
        type=Path,
    )
    parser.add_argument(
        "-v",
        "--verbose",
        help="Print only the recap if False",
        action="store_false",
    )
    parser.add_argument(
        "-t",
        "--threshold",
        help="Set a distance threshold for the match between gold and predicted entity.",
        default=THRESHOLD,
        type=threshold_float_type,
    )
    return parser.parse_args()
def main():
    args = parse_args()
    if args.annot:
        if not args.predict:
            raise argparse.ArgumentTypeError(
                "You need to specify the path to a predict file with -p"
            )
        run(args.annot, args.predict, args.threshold, args.verbose)
    elif args.csv:
        if not args.folder:
            raise argparse.ArgumentTypeError(
                "You need to specify the path to a folder of bio files with -f"
            )
        run_multiple(args.csv, args.folder, args.threshold, args.verbose)
    else:
        raise argparse.ArgumentTypeError(
            "You need to specify the argument of input file"
        )


if __name__ == "__main__":
    main()
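For orientation (an editor's sketch, not part of the commit): setup.py below wires nerval=nerval.cli:main as a console script, so this parser backs the nerval command. The same single-file evaluation can be driven from Python; the file names here are illustrative.

# Hypothetical programmatic equivalent of `nerval -a annot.bio -p predict.bio`.
from pathlib import Path

from nerval.evaluate import run

scores = run(Path("annot.bio"), Path("predict.bio"), threshold=0.30, verbose=False)
print(scores["All"]["F1"])  # micro-averaged F1 across entity types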
 # -*- coding: utf-8 -*-
-import argparse
 import glob
 import logging
 import os
-import re
 from csv import reader
 from pathlib import Path
@@ -12,211 +10,17 @@ import editdistance
 import edlib
 import termtables as tt

-NOT_ENTITY_TAG = "O"
-THRESHOLD = 0.30
-BEGINNING_POS = ["B", "S", "U"]
+from nerval.parse import (
+    BEGINNING_POS,
+    NOT_ENTITY_TAG,
+    get_position_label,
+    get_type_label,
+    look_for_further_entity_part,
+    parse_bio,
+)
+from nerval.utils import print_result_compact, print_results
+
+logger = logging.getLogger(__name__)
-
-
-def get_type_label(label: str) -> str:
-    """Return the type (tag) of a label
-
-    Input format: "[BIESLU]-type"
-    """
-    try:
-        tag = (
-            NOT_ENTITY_TAG
-            if label == NOT_ENTITY_TAG
-            else re.match(r"[BIESLU]-(.*)$", label)[1]
-        )
-    except TypeError:
-        raise (Exception(f"The label {label} is not valid in BIOES/BIOLU format."))
-    return tag
-
-
-def get_position_label(label: str) -> str:
-    """Return the position of a label
-
-    Input format: "[BIESLU]-type"
-    """
-    try:
-        pos = (
-            NOT_ENTITY_TAG
-            if label == NOT_ENTITY_TAG
-            else re.match(r"([BIESLU])-(.*)$", label)[1]
-        )
-    except TypeError:
-        raise (Exception(f"The label {label} is not valid in BIOES/BIOLU format."))
-    return pos
-
-
-def parse_bio(path: str) -> dict:
-    """Parse a BIO file to get text content, character-level NE labels and entity types count.
-
-    Input : path to a valid BIO file
-    Output format : { "words": str, "labels": list; "entity_count" : { tag : int } }
-    """
-    assert os.path.exists(path), f"Error: Input file {path} does not exist"
-    words = []
-    labels = []
-    entity_count = {"All": 0}
-    last_tag = None
-
-    with open(path, "r") as fd:
-        lines = list(filter(lambda x: x != "\n", fd.readlines()))
-
-    if "§" in " ".join(lines):
-        raise (
-            Exception(
-                f"§ found in input file {path}. Since this character is used in a specific way during evaluation, prease remove it from files."
-            )
-        )
-
-    # Track nested entities infos
-    in_nested_entity = False
-    containing_tag = None
-
-    for index, line in enumerate(lines):
-        try:
-            word, label = line.split()
-        except ValueError:
-            raise (
-                Exception(
-                    f"The file {path} given in input is not in BIO format: check line {index} ({line})"
-                )
-            )
-
-        # Preserve hyphens to avoid confusion with the hyphens added later during alignment
-        word = word.replace("-", "§")
-        words.append(word)
-
-        tag = get_type_label(label)
-
-        # Spaces will be added between words and have to get a label
-        if index != 0:
-            # If new word has same tag as previous, not new entity and in entity, continue entity
-            if (
-                last_tag == tag
-                and get_position_label(label) not in BEGINNING_POS
-                and tag != NOT_ENTITY_TAG
-            ):
-                labels.append(f"I-{last_tag}")
-
-            # If new word begins a new entity of different type, check for nested entity to correctly tag the space
-            elif (
-                last_tag != tag
-                and get_position_label(label) in BEGINNING_POS
-                and tag != NOT_ENTITY_TAG
-                and last_tag != NOT_ENTITY_TAG
-            ):
-                # Advance to next word with different label as current
-                future_label = label
-                while (
-                    index < len(lines)
-                    and future_label != NOT_ENTITY_TAG
-                    and get_type_label(future_label) != last_tag
-                ):
-                    index += 1
-                    if index < len(lines):
-                        future_label = lines[index].split()[1]
-
-                # Check for continuation of the original entity
-                if (
-                    index < len(lines)
-                    and get_position_label(future_label) not in BEGINNING_POS
-                    and get_type_label(future_label) == last_tag
-                ):
-                    labels.append(f"I-{last_tag}")
-                    in_nested_entity = True
-                    containing_tag = last_tag
-                else:
-                    labels.append(NOT_ENTITY_TAG)
-                    in_nested_entity = False
-
-            elif in_nested_entity:
-                labels.append(f"I-{containing_tag}")
-
-            else:
-                labels.append(NOT_ENTITY_TAG)
-                in_nested_entity = False
-
-        # Add a tag for each letter in the word
-        if get_position_label(label) in BEGINNING_POS:
-            labels += [f"B-{tag}"] + [f"I-{tag}"] * (len(word) - 1)
-        else:
-            labels += [label] * len(word)
-
-        # Count nb entity for each type
-        if get_position_label(label) in BEGINNING_POS:
-            entity_count[tag] = entity_count.get(tag, 0) + 1
-            entity_count["All"] += 1
-
-        last_tag = tag
-
-    result = None
-
-    if words:
-        result = dict()
-        result["words"] = " ".join(words)
-        result["labels"] = labels
-        result["entity_count"] = entity_count
-
-        assert len(result["words"]) == len(result["labels"])
-        for tag in result["entity_count"]:
-            if tag != "All":
-                assert result["labels"].count(f"B-{tag}") == result["entity_count"][tag]
-
-    return result
-
-
-def look_for_further_entity_part(index, tag, characters, labels):
-    """Get further entities parts for long entities with nested entities.
-
-    Input:
-        index: the starting index to look for rest of entity (one after last character included)
-        tag: the type of the entity investigated
-        characters: the string of the annotation or prediction
-        the labels associated with characters
-    Output :
-        complete string of the rest of the entity found
-        visited: indexes of the characters used for this last entity part OF THE DESIGNATED TAG. Do not process again later
-    """
-    original_index = index
-    last_loop_index = index
-    research = True
-    visited = []
-    while research:
-        while (
-            index < len(characters)
-            and labels[index] != NOT_ENTITY_TAG
-            and get_type_label(labels[index]) != tag
-        ):
-            index += 1
-        while (
-            index < len(characters)
-            and get_position_label(labels[index]) not in BEGINNING_POS
-            and get_type_label(labels[index]) == tag
-        ):
-            visited.append(index)
-            index += 1
-
-        research = index != last_loop_index and get_type_label(labels[index - 1]) == tag
-        last_loop_index = index
-
-    characters_to_add = (
-        characters[original_index:index]
-        if get_type_label(labels[index - 1]) == tag
-        else []
-    )
-
-    return characters_to_add, visited
 def compute_matches(
@@ -268,7 +72,6 @@ def compute_matches(
     # Iterating on reference string
     for i, char_annot in enumerate(annotation):
         if i in visited_annot:
             continue
@@ -282,7 +85,6 @@ def compute_matches(
             last_tag = NOT_ENTITY_TAG
         else:
-            # If beginning new entity
             if get_position_label(label_ref) in BEGINNING_POS:
                 current_ref, current_compar = [], []
@@ -294,7 +96,6 @@ def compute_matches(
             # Searching character string corresponding with tag
             if not found_aligned_end and tag_predict == tag_ref:
                 if i in visited_predict:
                     continue
@@ -335,7 +136,6 @@ def compute_matches(
                     i + 1 < len(annotation)
                     and get_type_label(labels_annot[i + 1]) != last_tag
                 ):
                     if not found_aligned_end:
                         rest_predict, visited = look_for_further_entity_part(
                             i + 1, tag_ref, prediction, labels_predict
@@ -466,50 +266,7 @@ def compute_scores(
     return scores
-def print_results(scores: dict):
-    """Display final results.
-
-    None values are kept to indicate the absence of a certain tag in either annotation or prediction.
-    """
-    header = ["tag", "predicted", "matched", "Precision", "Recall", "F1", "Support"]
-    results = []
-    for tag in sorted(scores.keys())[::-1]:
-        prec = None if scores[tag]["P"] is None else round(scores[tag]["P"], 3)
-        rec = None if scores[tag]["R"] is None else round(scores[tag]["R"], 3)
-        f1 = None if scores[tag]["F1"] is None else round(scores[tag]["F1"], 3)
-        results.append(
-            [
-                tag,
-                scores[tag]["predicted"],
-                scores[tag]["matched"],
-                prec,
-                rec,
-                f1,
-                scores[tag]["Support"],
-            ]
-        )
-    tt.print(results, header, style=tt.styles.markdown)
-
-
-def print_result_compact(scores: dict):
-    result = []
-    header = ["tag", "predicted", "matched", "Precision", "Recall", "F1", "Support"]
-    result.append(
-        [
-            "ALl",
-            scores["All"]["predicted"],
-            scores["All"]["matched"],
-            round(scores["All"]["P"], 3),
-            round(scores["All"]["R"], 3),
-            round(scores["All"]["F1"], 3),
-            scores["All"]["Support"],
-        ]
-    )
-    tt.print(result, header, style=tt.styles.markdown)
-
-
-def run(annotation: str, prediction: str, threshold: int, verbose: bool) -> dict:
+def run(annotation: Path, prediction: Path, threshold: int, verbose: bool) -> dict:
     """Compute recall and precision for each entity type found in annotation and/or prediction.

     Each measure is given at document level, global score is a micro-average across entity types.
@@ -586,16 +343,14 @@ def run_multiple(file_csv, folder, threshold, verbose):
             if annot and predict:
                 count += 1
-                print(os.path.basename(predict))
                 scores = run(annot, predict, threshold, verbose)
                 precision += scores["All"]["P"]
                 recall += scores["All"]["R"]
                 f1 += scores["All"]["F1"]
-                print()
             else:
                 raise Exception(f"No file found for files {annot}, {predict}")
         if count:
-            print("Average score on all corpus")
+            logger.info("Average score on all corpus")
             tt.print(
                 [
                     [
@@ -611,80 +366,3 @@ def run_multiple(file_csv, folder, threshold, verbose):
             raise Exception("No file were counted")
     else:
         raise Exception("the path indicated does not lead to a folder.")
-
-
-def threshold_float_type(arg):
-    """Type function for argparse."""
-    try:
-        f = float(arg)
-    except ValueError:
-        raise argparse.ArgumentTypeError("Must be a floating point number.")
-    if f < 0 or f > 1:
-        raise argparse.ArgumentTypeError("Must be between 0 and 1.")
-    return f
-
-
-def main():
-    """Get arguments and run."""
-    logging.basicConfig(level=logging.INFO)
-
-    parser = argparse.ArgumentParser(description="Compute score of NER on predict.")
-    group = parser.add_mutually_exclusive_group()
-    group.add_argument(
-        "-a",
-        "--annot",
-        help="Annotation in BIO format.",
-    )
-    group.add_argument(
-        "-c",
-        "--csv",
-        help="Csv with the correlation between the annotation bio files and the predict bio files",
-        type=Path,
-    )
-    parser.add_argument(
-        "-p",
-        "--predict",
-        help="Prediction in BIO format.",
-    )
-    parser.add_argument(
-        "-f",
-        "--folder",
-        help="Folder containing the bio files referred to in the csv file",
-        type=Path,
-    )
-    parser.add_argument(
-        "-v",
-        "--verbose",
-        help="Print only the recap if False",
-        action="store_false",
-    )
-    parser.add_argument(
-        "-t",
-        "--threshold",
-        help="Set a distance threshold for the match between gold and predicted entity.",
-        default=THRESHOLD,
-        type=threshold_float_type,
-    )
-    args = parser.parse_args()
-
-    if args.annot:
-        if not args.predict:
-            raise parser.error("You need to specify the path to a predict file with -p")
-        if args.annot and args.predict:
-            run(args.annot, args.predict, args.threshold, args.verbose)
-    elif args.csv:
-        if not args.folder:
-            raise parser.error(
-                "You need to specify the path to a folder of bio files with -f"
-            )
-        if args.folder and args.csv:
-            run_multiple(args.csv, args.folder, args.threshold, args.verbose)
-    else:
-        raise parser.error("You need to specify the argument of input file")
-
-
-if __name__ == "__main__":
-    main()
# -*- coding: utf-8 -*-
import re
from pathlib import Path

NOT_ENTITY_TAG = "O"
BEGINNING_POS = ["B", "S", "U"]


def get_type_label(label: str) -> str:
    """Return the type (tag) of a label

    Input format: "[BIESLU]-type"
    """
    try:
        tag = (
            NOT_ENTITY_TAG
            if label == NOT_ENTITY_TAG
            else re.match(r"[BIESLU]-(.*)$", label)[1]
        )
    except TypeError:
        raise (Exception(f"The label {label} is not valid in BIOES/BIOLU format."))
    return tag


def get_position_label(label: str) -> str:
    """Return the position of a label

    Input format: "[BIESLU]-type"
    """
    try:
        pos = (
            NOT_ENTITY_TAG
            if label == NOT_ENTITY_TAG
            else re.match(r"([BIESLU])-(.*)$", label)[1]
        )
    except TypeError:
        raise (Exception(f"The label {label} is not valid in BIOES/BIOLU format."))
    return pos
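A few illustrative calls for the two helpers above (editor's sketch, based on the regexes shown):

# Both helpers pass the "O" (outside) tag through unchanged.
assert get_type_label("B-PER") == "PER"
assert get_position_label("B-PER") == "B"
assert get_type_label("O") == "O"
# Any other shape ("PER", "X-PER", ...) raises Exception: not valid BIOES/BIOLU.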
def parse_bio(path: Path) -> dict:
    """Parse a BIO file to get text content, character-level NE labels and entity types count.

    Input : path to a valid BIO file
    Output format : { "words": str, "labels": list; "entity_count" : { tag : int } }
    """
    assert path.exists(), f"Error: Input file {path} does not exist"
    words = []
    labels = []
    entity_count = {"All": 0}
    last_tag = None

    with open(path, "r") as fd:
        lines = list(filter(lambda x: x != "\n", fd.readlines()))

    if "§" in " ".join(lines):
        raise (
            Exception(
                f"§ found in input file {path}. Since this character is used in a specific way during evaluation, prease remove it from files."
            )
        )

    # Track nested entities infos
    in_nested_entity = False
    containing_tag = None

    for index, line in enumerate(lines):
        try:
            word, label = line.split()
        except ValueError:
            raise (
                Exception(
                    f"The file {path} given in input is not in BIO format: check line {index} ({line})"
                )
            )

        # Preserve hyphens to avoid confusion with the hyphens added later during alignment
        word = word.replace("-", "§")
        words.append(word)

        tag = get_type_label(label)

        # Spaces will be added between words and have to get a label
        if index != 0:
            # If new word has same tag as previous, not new entity and in entity, continue entity
            if (
                last_tag == tag
                and get_position_label(label) not in BEGINNING_POS
                and tag != NOT_ENTITY_TAG
            ):
                labels.append(f"I-{last_tag}")

            # If new word begins a new entity of different type, check for nested entity to correctly tag the space
            elif (
                last_tag != tag
                and get_position_label(label) in BEGINNING_POS
                and tag != NOT_ENTITY_TAG
                and last_tag != NOT_ENTITY_TAG
            ):
                # Advance to next word with different label as current
                future_label = label
                while (
                    index < len(lines)
                    and future_label != NOT_ENTITY_TAG
                    and get_type_label(future_label) != last_tag
                ):
                    index += 1
                    if index < len(lines):
                        future_label = lines[index].split()[1]

                # Check for continuation of the original entity
                if (
                    index < len(lines)
                    and get_position_label(future_label) not in BEGINNING_POS
                    and get_type_label(future_label) == last_tag
                ):
                    labels.append(f"I-{last_tag}")
                    in_nested_entity = True
                    containing_tag = last_tag
                else:
                    labels.append(NOT_ENTITY_TAG)
                    in_nested_entity = False

            elif in_nested_entity:
                labels.append(f"I-{containing_tag}")

            else:
                labels.append(NOT_ENTITY_TAG)
                in_nested_entity = False

        # Add a tag for each letter in the word
        if get_position_label(label) in BEGINNING_POS:
            labels += [f"B-{tag}"] + [f"I-{tag}"] * (len(word) - 1)
        else:
            labels += [label] * len(word)

        # Count nb entity for each type
        if get_position_label(label) in BEGINNING_POS:
            entity_count[tag] = entity_count.get(tag, 0) + 1
            entity_count["All"] += 1

        last_tag = tag

    result = None

    if words:
        result = dict()
        result["words"] = " ".join(words)
        result["labels"] = labels
        result["entity_count"] = entity_count

        assert len(result["words"]) == len(result["labels"])
        for tag in result["entity_count"]:
            if tag != "All":
                assert result["labels"].count(f"B-{tag}") == result["entity_count"][tag]

    return result
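To make the output format concrete, an editor's sketch on a toy file (file name illustrative): labels are per character, and each separating space also receives a label.

from pathlib import Path

from nerval.parse import parse_bio

path = Path("toy.bio")  # hypothetical file
path.write_text("Paris B-LOC\nis O\nnice O\n")

result = parse_bio(path)
assert result["words"] == "Paris is nice"
# One label per character; index 5 is the space after "Paris".
assert result["labels"][:6] == ["B-LOC", "I-LOC", "I-LOC", "I-LOC", "I-LOC", "O"]
assert result["entity_count"] == {"All": 1, "LOC": 1}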
def look_for_further_entity_part(index, tag, characters, labels):
    """Get further entities parts for long entities with nested entities.

    Input:
        index: the starting index to look for rest of entity (one after last character included)
        tag: the type of the entity investigated
        characters: the string of the annotation or prediction
        the labels associated with characters
    Output :
        complete string of the rest of the entity found
        visited: indexes of the characters used for this last entity part OF THE DESIGNATED TAG. Do not process again later
    """
    original_index = index
    last_loop_index = index
    research = True
    visited = []
    while research:
        while (
            index < len(characters)
            and labels[index] != NOT_ENTITY_TAG
            and get_type_label(labels[index]) != tag
        ):
            index += 1
        while (
            index < len(characters)
            and get_position_label(labels[index]) not in BEGINNING_POS
            and get_type_label(labels[index]) == tag
        ):
            visited.append(index)
            index += 1

        research = index != last_loop_index and get_type_label(labels[index - 1]) == tag
        last_loop_index = index

    characters_to_add = (
        characters[original_index:index]
        if get_type_label(labels[index - 1]) == tag
        else []
    )

    return characters_to_add, visited
# -*- coding: utf-8 -*-
import termtables as tt


def print_results(scores: dict):
    """Display final results.

    None values are kept to indicate the absence of a certain tag in either annotation or prediction.
    """
    header = ["tag", "predicted", "matched", "Precision", "Recall", "F1", "Support"]
    results = []
    for tag in sorted(scores, reverse=True):
        prec = None if scores[tag]["P"] is None else round(scores[tag]["P"], 3)
        rec = None if scores[tag]["R"] is None else round(scores[tag]["R"], 3)
        f1 = None if scores[tag]["F1"] is None else round(scores[tag]["F1"], 3)
        results.append(
            [
                tag,
                scores[tag]["predicted"],
                scores[tag]["matched"],
                prec,
                rec,
                f1,
                scores[tag]["Support"],
            ]
        )
    tt.print(results, header, style=tt.styles.markdown)


def print_result_compact(scores: dict):
    header = ["tag", "predicted", "matched", "Precision", "Recall", "F1", "Support"]
    result = [
        [
            "All",
            scores["All"]["predicted"],
            scores["All"]["matched"],
            round(scores["All"]["P"], 3),
            round(scores["All"]["R"], 3),
            round(scores["All"]["F1"], 3),
            scores["All"]["Support"],
        ]
    ]
    tt.print(result, header, style=tt.styles.markdown)
-#!/usr/bin/env python
 # -*- coding: utf-8 -*-
-import os.path
+from pathlib import Path

-from setuptools import setup
+from setuptools import find_packages, setup


-def requirements(path):
-    assert os.path.exists(path), "Missing requirements {}.format(path)"
-    with open(path) as f:
-        return list(map(str.strip, f.read().splitlines()))
+def parse_requirements_line(line):
+    """Special case for git requirements"""
+    if line.startswith("git+http"):
+        assert "@" in line, "Branch should be specified with suffix (ex: @master)"
+        assert (
+            "#egg=" in line
+        ), "Package name should be specified with suffix (ex: #egg=kraken)"
+        package_name = line.split("#egg=")[-1]
+        return f"{package_name} @ {line}"
+    else:
+        return line


-install_requires = requirements("requirements.txt")
+def parse_requirements():
+    path = Path(__file__).parent.resolve() / "requirements.txt"
+    assert path.exists(), f"Missing requirements: {path}"
+    return list(
+        map(parse_requirements_line, map(str.strip, path.read_text().splitlines()))
+    )
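An editor's illustration of the git-requirement rewrite above (URL hypothetical); plain requirement lines pass through untouched:

line = "git+https://gitlab.com/example/kraken.git@master#egg=kraken"
assert parse_requirements_line(line) == f"kraken @ {line}"  # PEP 508 direct reference
assert parse_requirements_line("editdistance==0.6.2") == "editdistance==0.6.2"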
 setup(
-    name="Nerval",
+    name="nerval",
     version=open("VERSION").read(),
     description="Tool to evaluate NER on noisy text.",
     author="Teklia",
-    author_email="bmiret@teklia.com",
-    packages=["nerval"],
-    entry_points={"console_scripts": ["nerval=nerval.evaluate:main"]},
-    install_requires=install_requires,
+    author_email="contact@teklia.com",
+    packages=find_packages(),
+    entry_points={"console_scripts": ["nerval=nerval.cli:main"]},
+    install_requires=parse_requirements(),
 )
# -*- coding: utf-8 -*-
from pathlib import Path

import pytest

FIXTURES = Path(__file__).parent / "fixtures"


@pytest.fixture
def fake_annot_bio():
    return FIXTURES / "test_annot.bio"


@pytest.fixture
def fake_predict_bio():
    return FIXTURES / "test_predict.bio"


@pytest.fixture
def empty_bio():
    return FIXTURES / "test_empty.bio"


@pytest.fixture
def bad_bio():
    return FIXTURES / "test_bad.bio"


@pytest.fixture
def bioeslu_bio():
    return FIXTURES / "bioeslu.bio"


@pytest.fixture
def end_of_file_bio():
    return FIXTURES / "end_of_file.bio"


@pytest.fixture
def nested_bio():
    return FIXTURES / "test_nested.bio"


@pytest.fixture
def folder_bio():
    return Path("test_folder")


@pytest.fixture()
def csv_file():
    return Path("test_mapping_file.csv")
File moved
File moved
File moved
File moved
File moved
File moved
File moved
@@ -2,21 +2,21 @@
 import edlib
 import pytest

-fake_annot_original = "Gérard de Nerval was born in Paris in 1808 ."
-fake_predict_original = "G*rard de *N*erval bo*rn in Paris in 1833 *."
-
-expected_alignment = {
-    "query_aligned": "Gérard de -N-erval was bo-rn in Paris in 1808 -.",
-    "matched_aligned": "|.||||||||-|-||||||----||-|||||||||||||||||..|-|",
-    "target_aligned": "G*rard de *N*erval ----bo*rn in Paris in 1833 *.",
-}
-

 @pytest.mark.parametrize(
-    "test_input, expected",
-    [((fake_annot_original, fake_predict_original), expected_alignment)],
+    "query,target",
+    [
+        (
+            "Gérard de Nerval was born in Paris in 1808 .",
+            "G*rard de *N*erval bo*rn in Paris in 1833 *.",
+        )
+    ],
 )
-def test_align(test_input, expected):
-    a = edlib.align(*test_input, task="path")
-    result_alignment = edlib.getNiceAlignment(a, *test_input)
-    assert result_alignment == expected
+def test_align(query, target):
+    a = edlib.align(query, target, task="path")
+    result_alignment = edlib.getNiceAlignment(a, query, target)
+    assert result_alignment == {
+        "query_aligned": "Gérard de -N-erval was bo-rn in Paris in 1808 -.",
+        "matched_aligned": "|.||||||||-|-||||||----||-|||||||||||||||||..|-|",
+        "target_aligned": "G*rard de *N*erval ----bo*rn in Paris in 1833 *.",
+    }
@@ -5,10 +5,6 @@ from nerval import evaluate
 THRESHOLD = 0.30

-fake_annot_aligned = "Gérard de -N-erval was bo-rn in Paris in 1808 -."
-fake_predict_aligned = "G*rard de *N*erval ----bo*rn in Paris in 1833 *."
-
-fake_string_nested = "Louis par la grâce de Dieu roy de France et de Navarre."

 # fmt: off
 fake_tags_aligned_nested_perfect = [
@@ -141,12 +137,6 @@ fake_annot_tags_aligned = [
     "O",
 ]

-expected_matches = {"All": 1, "PER": 1, "LOC": 0, "DAT": 0}
-expected_matches_nested_perfect = {"All": 3, "PER": 1, "LOC": 2}
-expected_matches_nested_false = {"All": 2, "PER": 1, "LOC": 1}
-
-fake_annot_backtrack_boundary = "The red dragon"
-
 fake_annot_tags_bk_boundary = [
     "O",
     "O",
@@ -181,10 +171,6 @@ fake_predict_tags_bk_boundary = [
     "I-PER",
 ]

-expected_matches_bk_boundary = {"All": 0, "PER": 0}
-
-fake_annot_backtrack_boundary_2 = "A red dragon"
-
 fake_annot_tags_bk_boundary_2 = [
     "O",
     "O",
@@ -215,61 +201,59 @@ fake_predict_tags_bk_boundary_2 = [
     "I-PER",
 ]

-expected_matches_bk_boundary_2 = {"All": 1, "PER": 1}


 @pytest.mark.parametrize(
     "test_input, expected",
     [
         (
             (
-                fake_annot_aligned,
-                fake_predict_aligned,
+                "Gérard de -N-erval was bo-rn in Paris in 1808 -.",
+                "G*rard de *N*erval ----bo*rn in Paris in 1833 *.",
                 fake_annot_tags_aligned,
                 fake_predict_tags_aligned,
                 THRESHOLD,
             ),
-            expected_matches,
+            {"All": 1, "PER": 1, "LOC": 0, "DAT": 0},
         ),
         (
             (
-                fake_string_nested,
-                fake_string_nested,
+                "Louis par la grâce de Dieu roy de France et de Navarre.",
+                "Louis par la grâce de Dieu roy de France et de Navarre.",
                 fake_tags_aligned_nested_perfect,
                 fake_tags_aligned_nested_perfect,
                 THRESHOLD,
             ),
-            expected_matches_nested_perfect,
+            {"All": 3, "PER": 1, "LOC": 2},
         ),
         (
             (
-                fake_string_nested,
-                fake_string_nested,
+                "Louis par la grâce de Dieu roy de France et de Navarre.",
+                "Louis par la grâce de Dieu roy de France et de Navarre.",
                 fake_tags_aligned_nested_perfect,
                 fake_tags_aligned_nested_false,
                 THRESHOLD,
             ),
-            expected_matches_nested_false,
+            {"All": 2, "PER": 1, "LOC": 1},
         ),
         (
             (
-                fake_annot_backtrack_boundary,
-                fake_annot_backtrack_boundary,
+                "The red dragon",
+                "The red dragon",
                 fake_annot_tags_bk_boundary,
                 fake_predict_tags_bk_boundary,
                 THRESHOLD,
             ),
-            expected_matches_bk_boundary,
+            {"All": 0, "PER": 0},
        ),
         (
             (
-                fake_annot_backtrack_boundary_2,
-                fake_annot_backtrack_boundary_2,
+                "A red dragon",
+                "A red dragon",
                 fake_annot_tags_bk_boundary_2,
                 fake_predict_tags_bk_boundary_2,
                 THRESHOLD,
             ),
-            expected_matches_bk_boundary_2,
+            {"All": 1, "PER": 1},
         ),
     ],
 )
...