Skip to content
Snippets Groups Projects
Commit fc146a8d authored by Eva Bardou's avatar Eva Bardou :frog: Committed by Yoann Schneider
Browse files

Convert NER prediction to BIO format

parent 65128721
No related branches found
No related tags found
1 merge request!325Convert NER prediction to BIO format
import logging
import re
from typing import Dict, List
from dan.utils import EntityType
logger = logging.getLogger(__name__)
def convert(text: str, ner_tokens: Dict[str, EntityType]) -> str:
# Mapping to find a starting token for an ending token efficiently
mapping_end_start: Dict[str, str] = {
entity_type.end: entity_type.start for entity_type in ner_tokens.values()
}
# Mapping to find the entity name for a starting token efficiently
mapping_start_name: Dict[str, str] = {
entity_type.start: name for name, entity_type in ner_tokens.items()
}
starting_tokens: List[str] = mapping_start_name.keys()
ending_tokens: List[str] = mapping_end_start.keys()
has_ending_tokens: bool = set(ending_tokens) != {
""
} # Whether ending tokens are used
# Spacing starting tokens and ending tokens (if necessary)
tokens_spacing: re.Pattern = re.compile(
r"([" + "".join([*starting_tokens, *ending_tokens]) + "])"
)
text: str = tokens_spacing.sub(r" \1 ", text)
iob: List[str] = [] # List of IOB formatted strings
entity_types: List[str] = [] # Encountered entity types
inside: bool = False # Whether we are inside an entity
for token in text.split():
# Encountering a starting token
if token in starting_tokens:
entity_types.append(token)
# Stopping any current entity type
inside = False
continue
# Encountering an ending token
elif has_ending_tokens and token in ending_tokens:
if not entity_types:
logger.warning(
f"Missing starting token for ending token {token}, skipping the entity"
)
continue
# Making sure this ending token closes the current entity
assert (
entity_types[-1] == mapping_end_start[token]
), f"Ending token {token} doesn't match the starting token {entity_types[-1]}"
# Removing the current entity from the queue as it is its end
entity_types.pop()
# If there is still entities in the queue, we continue in the parent one
# Else, we are not in any entity anymore
inside = bool(entity_types)
continue
# The token is not part of an entity
if not entity_types:
iob.append(f"{token} O")
continue
# The token is part of at least one entity
entity_name: str = mapping_start_name[entity_types[-1]]
if inside:
# Inside the same entity
iob.append(f"{token} I-{entity_name}")
continue
# Starting a new entity
iob.append(f"{token} B-{entity_name}")
inside = True
# Concatenating all formatted iob strings
return "\n".join(iob)
import logging
import pytest
from dan.bio import convert
from dan.utils import EntityType
ST_TEXT = """ⒶBryan B ⒷParis ⒸJanuary 1st, 1987
ⒶJoe J ⒷGrenoble ⒸAugust 24, 1995
ⒶHannah H ⒷLille ⒸSeptember 15, 2002"""
ST_ET_TEXT = """ⒶBryanⒷ and ⒶJoeⒷ will visit the ⒸEiffel TowerⒹ in ⒸParisⒹ next ⒺTuesdayⒻ.
ⒶHannahⒷ will visit the ⒸPlace ⒶCharles de GaulleⒷ étoileⒹ on ⒺWednesdayⒻ."""
def test_convert_with_error():
ner_tokens = {
"Person": EntityType(start="", end=""),
"Location": EntityType(start="", end=""),
}
with pytest.raises(
AssertionError, match="Ending token Ⓓ doesn't match the starting token Ⓐ"
):
convert("ⒶFredⒹ", ner_tokens)
def test_convert_with_warnings(caplog):
ner_tokens = {
"Person": EntityType(start="", end=""),
"Location": EntityType(start="", end=""),
}
assert convert("BryanⒷ and ⒶJoeⒷ will visit the Eiffel TowerⒹ", ner_tokens).split(
"\n"
) == [
"Bryan O",
"and O",
"Joe B-Person",
"will O",
"visit O",
"the O",
"Eiffel O",
"Tower O",
]
assert [(level, message) for _, level, message in caplog.record_tuples] == [
(
logging.WARNING,
"Missing starting token for ending token Ⓑ, skipping the entity",
),
(
logging.WARNING,
"Missing starting token for ending token Ⓓ, skipping the entity",
),
]
def test_convert_starting_tokens():
ner_tokens = {
"Person": EntityType(start=""),
"Location": EntityType(start=""),
"Date": EntityType(start=""),
}
assert convert(ST_TEXT, ner_tokens).split("\n") == [
"Bryan B-Person",
"B I-Person",
"Paris B-Location",
"January B-Date",
"1st, I-Date",
"1987 I-Date",
"Joe B-Person",
"J I-Person",
"Grenoble B-Location",
"August B-Date",
"24, I-Date",
"1995 I-Date",
"Hannah B-Person",
"H I-Person",
"Lille B-Location",
"September B-Date",
"15, I-Date",
"2002 I-Date",
]
def test_convert_starting_and_ending_tokens():
ner_tokens = {
"Person": EntityType(start="", end=""),
"Location": EntityType(start="", end=""),
"Date": EntityType(start="", end=""),
}
assert convert(ST_ET_TEXT, ner_tokens).split("\n") == [
"Bryan B-Person",
"and O",
"Joe B-Person",
"will O",
"visit O",
"the O",
"Eiffel B-Location",
"Tower I-Location",
"in O",
"Paris B-Location",
"next O",
"Tuesday B-Date",
". O",
"Hannah B-Person",
"will O",
"visit O",
"the O",
"Place B-Location",
"Charles B-Person",
"de I-Person",
"Gaulle I-Person",
"étoile I-Location",
"on O",
"Wednesday B-Date",
". O",
]
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment