Skip to content
Snippets Groups Projects

Convert NER prediction to BIO format

Merged Eva Bardou requested to merge convert-to-BIO into main
All threads resolved!
2 files
+ 147
11
Compare changes
  • Side-by-side
  • Inline
Files
2
+ 27
11
import logging
from operator import attrgetter
from typing import Dict, List
from typing import Dict, List, Optional
from dan.utils import EntityType
logger = logging.getLogger(__name__)
def find_starting_token(
    ner_tokens: Dict[str, "EntityType"], ending_token: str
) -> Optional[str]:
    """Return the starting token paired with the given ending token.

    Iterates over the entity types stored as values of ``ner_tokens`` and
    returns the ``start`` attribute of the first one whose ``end`` attribute
    equals ``ending_token``.

    :param ner_tokens: Mapping of entity names to their entity type; each
        value must expose ``start`` and ``end`` attributes.
    :param ending_token: Ending token to look up.
    :return: The matching starting token, or ``None`` when no entity type
        uses ``ending_token`` as its ending token.
    """
    for entity_type in ner_tokens.values():
        if entity_type.end == ending_token:
            return entity_type.start
    # No entity type declares this ending token.
    return None
def find_entity_name(
    ner_tokens: Dict[str, "EntityType"], starting_token: str
) -> Optional[str]:
    """Return the name of the entity introduced by the given starting token.

    Iterates over the ``(name, entity_type)`` pairs of ``ner_tokens`` and
    returns the name of the first entity type whose ``start`` attribute
    equals ``starting_token``.

    :param ner_tokens: Mapping of entity names to their entity type; each
        value must expose a ``start`` attribute.
    :param starting_token: Starting token to look up.
    :return: The matching entity name, or ``None`` when no entity type uses
        ``starting_token`` as its starting token.
    """
    for name, entity_type in ner_tokens.items():
        if entity_type.start == starting_token:
            return name
    # No entity type declares this starting token.
    return None
@@ -23,10 +30,12 @@ def convert(text: str, ner_tokens: Dict[str, EntityType]) -> str:
text = text.replace(starting_token, f" {starting_token} ")
ending_tokens: List[str] = list(map(attrgetter("end"), ner_tokens.values()))
for ending_token in ending_tokens:
text = text.replace(ending_token, f" {ending_token} ")
has_ending_tokens: bool = bool(len(ending_tokens))
has_ending_tokens: bool = set(ending_tokens) != {
""
} # Whether ending tokens are used
if has_ending_tokens:
for ending_token in ending_tokens:
text = text.replace(ending_token, f" {ending_token} ")
iob: List[str] = [] # List of IOB formatted strings
entity_types: List[str] = [] # Encountered entity types
@@ -37,15 +46,22 @@ def convert(text: str, ner_tokens: Dict[str, EntityType]) -> str:
entity_types.append(token)
# Stopping any current entity type
if has_ending_tokens:
inside = False
inside = False
continue
# Encountering an ending token
elif has_ending_tokens and token in ending_tokens:
if not entity_types:
logger.warning(
f"Missing starting token for ending token {token}, skipping the entity"
)
continue
# Making sure this ending token closes the current entity
assert entity_types[-1] == find_starting_token(ner_tokens, token)
assert entity_types[-1] == find_starting_token(
ner_tokens, token
), f"Ending token {token} doesn't match the starting token {entity_types[-1]}"
# Removing the current entity from the queue as it is its end
entity_types.pop()
@@ -62,7 +78,7 @@ def convert(text: str, ner_tokens: Dict[str, EntityType]) -> str:
continue
# The token is part of at least one entity
entity_name = find_entity_name(ner_tokens, entity_types[-1])
entity_name: str = find_entity_name(ner_tokens, entity_types[-1])
if inside:
# Inside the same entity
Loading