diff --git a/dan/bio.py b/dan/bio.py
new file mode 100644
index 0000000000000000000000000000000000000000..d676d937cc652fa88c4d18919c38b9c53e3d1a39
--- /dev/null
+++ b/dan/bio.py
@@ -0,0 +1,90 @@
+import logging
+import re
+from typing import Dict, List
+
+from dan.utils import EntityType
+
+logger = logging.getLogger(__name__)
+
+
+def convert(text: str, ner_tokens: Dict[str, EntityType]) -> str:
+    """Convert a text tagged with NER start/end tokens to the BIO format.
+
+    Each line of the output is formatted as "<word> <tag>", where the tag
+    is O, B-<entity name> or I-<entity name>.
+    """
+    # Mapping to find a starting token for an ending token efficiently
+    mapping_end_start: Dict[str, str] = {
+        entity_type.end: entity_type.start for entity_type in ner_tokens.values()
+    }
+    # Mapping to find the entity name for a starting token efficiently
+    mapping_start_name: Dict[str, str] = {
+        entity_type.start: name for name, entity_type in ner_tokens.items()
+    }
+
+    starting_tokens: List[str] = list(mapping_start_name.keys())
+    ending_tokens: List[str] = list(mapping_end_start.keys())
+    has_ending_tokens: bool = set(ending_tokens) != {
+        ""
+    }  # Whether ending tokens are used
+
+    # Spacing out starting and ending tokens (escaped, as a token may be a regex metacharacter)
+    tokens_spacing: re.Pattern = re.compile(
+        r"([" + "".join(map(re.escape, [*starting_tokens, *ending_tokens])) + "])"
+    )
+    text: str = tokens_spacing.sub(r" \1 ", text)
+
+    iob: List[str] = []  # List of IOB formatted strings
+    entity_types: List[str] = []  # Stack of encountered entity types
+    inside: bool = False  # Whether we are inside an entity
+    for token in text.split():
+        # Encountering a starting token
+        if token in starting_tokens:
+            entity_types.append(token)
+
+            # Stopping any current entity type
+            inside = False
+
+            continue
+
+        # Encountering an ending token
+        elif has_ending_tokens and token in ending_tokens:
+            if not entity_types:
+                logger.warning(
+                    f"Missing starting token for ending token {token}, skipping the entity"
+                )
+                continue
+
+            # Making sure this ending token closes the current entity
+            assert (
+                entity_types[-1] == mapping_end_start[token]
+            ), f"Ending token {token} doesn't match the starting token {entity_types[-1]}"
+
+            # Removing the current entity from the stack as this token ends it
+            entity_types.pop()
+
+            # If there are still entities in the stack, we continue in the parent one
+            # Else, we are not in any entity anymore
+            inside = bool(entity_types)
+
+            continue
+
+        # The token is not part of an entity
+        if not entity_types:
+            iob.append(f"{token} O")
+            continue
+
+        # The token is part of at least one entity
+        entity_name: str = mapping_start_name[entity_types[-1]]
+
+        if inside:
+            # Inside the same entity
+            iob.append(f"{token} I-{entity_name}")
+            continue
+
+        # Starting a new entity
+        iob.append(f"{token} B-{entity_name}")
+        inside = True
+
+    # Concatenating all formatted IOB strings
+    return "\n".join(iob)
diff --git a/tests/test_bio.py b/tests/test_bio.py
new file mode 100644
index 0000000000000000000000000000000000000000..fce81268da39ed945b88e797a2a03cf227d585a3
--- /dev/null
+++ b/tests/test_bio.py
@@ -0,0 +1,120 @@
+import logging
+
+import pytest
+
+from dan.bio import convert
+from dan.utils import EntityType
+
+ST_TEXT = """ⒶBryan B ⒷParis ⒸJanuary 1st, 1987
+ⒶJoe J ⒷGrenoble ⒸAugust 24, 1995
+ⒶHannah H ⒷLille ⒸSeptember 15, 2002"""
+
+ST_ET_TEXT = """ⒶBryanⒷ and ⒶJoeⒷ will visit the ⒸEiffel TowerⒹ in ⒸParisⒹ next ⒺTuesdayⒻ.
+ⒶHannahⒷ will visit the ⒸPlace ⒶCharles de GaulleⒷ étoileⒹ on ⒺWednesdayⒻ."""
+
+
+def test_convert_with_error():
+    ner_tokens = {
+        "Person": EntityType(start="Ⓐ", end="Ⓑ"),
+        "Location": EntityType(start="Ⓒ", end="Ⓓ"),
+    }
+
+    with pytest.raises(
+        AssertionError, match="Ending token Ⓓ doesn't match the starting token Ⓐ"
+    ):
+        convert("ⒶFredⒹ", ner_tokens)
+
+
+def test_convert_with_warnings(caplog):
+    ner_tokens = {
+        "Person": EntityType(start="Ⓐ", end="Ⓑ"),
+        "Location": EntityType(start="Ⓒ", end="Ⓓ"),
+    }
+
+    assert convert("BryanⒷ and ⒶJoeⒷ will visit the Eiffel TowerⒹ", ner_tokens).split(
+        "\n"
+    ) == [
+        "Bryan O",
+        "and O",
+        "Joe B-Person",
+        "will O",
+        "visit O",
+        "the O",
+        "Eiffel O",
+        "Tower O",
+    ]
+    assert [(level, message) for _, level, message in caplog.record_tuples] == [
+        (
+            logging.WARNING,
+            "Missing starting token for ending token Ⓑ, skipping the entity",
+        ),
+        (
+            logging.WARNING,
+            "Missing starting token for ending token Ⓓ, skipping the entity",
+        ),
+    ]
+
+
+def test_convert_starting_tokens():
+    ner_tokens = {
+        "Person": EntityType(start="Ⓐ"),
+        "Location": EntityType(start="Ⓑ"),
+        "Date": EntityType(start="Ⓒ"),
+    }
+
+    assert convert(ST_TEXT, ner_tokens).split("\n") == [
+        "Bryan B-Person",
+        "B I-Person",
+        "Paris B-Location",
+        "January B-Date",
+        "1st, I-Date",
+        "1987 I-Date",
+        "Joe B-Person",
+        "J I-Person",
+        "Grenoble B-Location",
+        "August B-Date",
+        "24, I-Date",
+        "1995 I-Date",
+        "Hannah B-Person",
+        "H I-Person",
+        "Lille B-Location",
+        "September B-Date",
+        "15, I-Date",
+        "2002 I-Date",
+    ]
+
+
+def test_convert_starting_and_ending_tokens():
+    ner_tokens = {
+        "Person": EntityType(start="Ⓐ", end="Ⓑ"),
+        "Location": EntityType(start="Ⓒ", end="Ⓓ"),
+        "Date": EntityType(start="Ⓔ", end="Ⓕ"),
+    }
+
+    assert convert(ST_ET_TEXT, ner_tokens).split("\n") == [
+        "Bryan B-Person",
+        "and O",
+        "Joe B-Person",
+        "will O",
+        "visit O",
+        "the O",
+        "Eiffel B-Location",
+        "Tower I-Location",
+        "in O",
+        "Paris B-Location",
+        "next O",
+        "Tuesday B-Date",
+        ". O",
+        "Hannah B-Person",
+        "will O",
+        "visit O",
+        "the O",
+        "Place B-Location",
+        "Charles B-Person",
+        "de I-Person",
+        "Gaulle I-Person",
+        "étoile I-Location",
+        "on O",
+        "Wednesday B-Date",
+        ". O",
+    ]
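
For reviewers, a minimal usage sketch of the new `convert` function follows (not part of the diff; it assumes `dan.utils.EntityType` accepts `start`/`end` keyword arguments, as the tests above do):

```python
from dan.bio import convert
from dan.utils import EntityType

# Entities are marked in the input text with start/end tokens,
# here the circled letters used throughout the tests
ner_tokens = {
    "Person": EntityType(start="Ⓐ", end="Ⓑ"),
    "Location": EntityType(start="Ⓒ", end="Ⓓ"),
}

# Each output line is "<word> <tag>" with O, B-<name> or I-<name> tags
print(convert("ⒶBryanⒷ lives in ⒸParisⒹ", ner_tokens))
# Bryan B-Person
# lives O
# in O
# Paris B-Location
```

Nested entities are handled through the internal stack: in the last test case, `Charles de Gaulle` is tagged as a Person inside the surrounding `Place … étoile` Location, and tagging resumes in the parent Location once the inner entity is closed.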