Skip to content
Snippets Groups Projects

Add threshold as option

Merged Thibault Lavigne requested to merge add_threshold_as_option into master
All threads resolved!
4 files
+ 66
− 13
Compare changes
  • Side-by-side
  • Inline
Files
4
+ 34
− 6
@@ -9,9 +9,10 @@ import editdistance
import edlib
import termtables as tt
THRESHOLD = 0.30
NOT_ENTITY_TAG = "O"
THRESHOLD = 0.30
def get_type_label(label: str) -> str:
"""Return the type (tag) of a label
@@ -189,7 +190,11 @@ def look_for_further_entity_part(index, tag, characters, labels):
def compute_matches(
annotation: str, prediction: str, labels_annot: list, labels_predict: list
annotation: str,
prediction: str,
labels_annot: list,
labels_predict: list,
threshold: int,
) -> dict:
"""Compute prediction score from annotation string to prediction string.
@@ -274,7 +279,7 @@
):
j -= 1
if (
"B" in labels_predict[j]
and get_type_label(labels_predict[j]) == tag_ref
and j not in visited_predict
@@ -324,7 +329,7 @@
score = (
1
if editdistance.eval(entity_ref, entity_compar) / len_entity
< THRESHOLD
<= threshold
else 0
)
entity_count[last_tag] = entity_count.get(last_tag, 0) + score
@@ -454,7 +459,7 @@ def print_results(scores: dict):
tt.print(results, header, style=tt.styles.markdown)
def run(annotation: str, prediction: str) -> dict:
def run(annotation: str, prediction: str, threshold: int) -> dict:
"""Compute recall and precision for each entity type found in annotation and/or prediction.
Each measure is given at document level, global score is a micro-average across entity types.
@@ -486,7 +491,11 @@ def run(annotation: str, prediction: str) -> dict:
# Get nb match
matches = compute_matches(
annot_aligned, predict_aligned, labels_annot_aligned, labels_predict_aligned
annot_aligned,
predict_aligned,
labels_annot_aligned,
labels_predict_aligned,
threshold,
)
# Compute scores
@@ -498,6 +507,17 @@ def run(annotation: str, prediction: str) -> dict:
return scores
def threshold_float_type(arg):
    """Argparse ``type`` function: parse a float restricted to [0, 1].

    Args:
        arg: Raw command-line string handed over by argparse.

    Returns:
        float: The parsed threshold value.

    Raises:
        argparse.ArgumentTypeError: If ``arg`` is not a valid float or
            lies outside the inclusive range [0, 1].
    """
    try:
        f = float(arg)
    except ValueError:
        # `from None` suppresses the chained ValueError traceback — the
        # original parse error adds nothing for a CLI user.
        raise argparse.ArgumentTypeError("Must be a floating point number.") from None
    # Chained comparison is clearer than `f < 0 or f > 1`.
    if not 0.0 <= f <= 1.0:
        raise argparse.ArgumentTypeError("Must be between 0 and 1.")
    return f
def main():
"""Get arguments and run."""
@@ -510,9 +530,17 @@ def main():
parser.add_argument(
"-p", "--predict", help="Prediction in BIO format.", required=True
)
parser.add_argument(
"-t",
"--threshold",
help="Set a distance threshold for the match between gold and predicted entity.",
required=False,
default=THRESHOLD,
type=threshold_float_type,
)
args = parser.parse_args()
run(args.annot, args.predict)
run(args.annot, args.predict, args.threshold)
if __name__ == "__main__":
Loading