Commit 04326cc2 authored by Manon Blanco

Merge branch 'use-regex-for-ner-parsing' into 'main'

Use regex to parse NER entities

Closes #202

See merge request !392
parents e4fd588d 5fe46062
@@ -2,7 +2,6 @@
import logging
import re
from enum import Enum
from itertools import pairwise
from typing import Dict, List, Tuple
import cv2
@@ -28,78 +27,30 @@ class Level(str, Enum):
def parse_delimiters(delimiters: List[str]) -> re.Pattern:
return re.compile(r"|".join(delimiters))
return re.compile(rf"[^{'|'.join(delimiters)}]+")
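For illustration, the new pattern matches runs of non-separator characters rather than the separators themselves; a minimal sketch, importing parse_delimiters the same way the tests below do:

from dan.ocr.predict.attention import parse_delimiters

word_separators = parse_delimiters(["\n", " "])
# The compiled pattern is a negated character class: anything that is neither
# a newline, a space, nor the literal "|" used to join the delimiter list.
print(repr(word_separators.pattern))            # '[^\n| ]+'
# Finding all matches therefore yields the text pieces directly.
print(word_separators.findall("Lo ve\nTokyo"))  # ['Lo', 've', 'Tokyo']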
def build_ner_indices(
text: str, tokens: Dict[str, EntityType]
) -> List[Tuple[int, int]]:
"""
Compute the position of NER tokens in the text and return a list of indices.
:param text: list of characters.
:param tokens: NER tokens used.
Returns a list of indices where tokens are located.
"""
start_tokens, end_tokens = zip(*list(tokens.values()))
end_tokens = list(filter(bool, end_tokens))
if len(end_tokens):
assert len(start_tokens) == len(
end_tokens
), "You don't have the same number of starting tokens and ending tokens"
return [
[pos_start, pos_end] for pos_start, pos_end in zip(start_tokens, end_tokens)
]
return list(
pairwise(
[pos for pos, char in enumerate(text) if char in start_tokens] + [None]
)
)
def compute_offsets_by_level(
level: Level, text_list: List[str], indices: List[Tuple[int, int]]
):
def compute_offsets_by_level(full_text: str, level: Level, text_list: List[str]):
"""
Compute and return the list of offsets between each text part.
:param full_text: predicted text.
:param level: Level to use from [char, word, line, ner].
:param text_list: list of text to use.
:param indices: list of indices where tokens are located for NER computation.
Returns a list of offsets.
"""
if level == Level.NER:
return (
[
current - next_token
for (_, next_token), (current, _) in pairwise(indices)
]
# Pad the list to match the length of the text list
+ [0]
)
return [int(level != Level.Char)] * len(text_list)
# offsets[idx] = number of characters between text_list[idx] and text_list[idx + 1]
offsets = [int(level != Level.Char)] * (len(text_list) - 1)
if level == Level.NER:
# Start after the first entity
cursor = len(text_list[0])
for idx, split in enumerate(text_list[1:]):
# Number of characters between this entity and the previous one
offsets[idx] = full_text[cursor:].index(split)
cursor += offsets[idx] + len(split)
def compute_prob_by_ner(
characters: str, probabilities: List[float], indices: List[Tuple[int, int]]
) -> Tuple[List[str], List[np.float64]]:
"""
Split text and confidences using indices and return a list of average confidence scores.
:param characters: list of characters.
:param probabilities: list of character probabilities.
:param indices: list of indices where tokens are located.
Returns a list of confidence scores.
"""
return zip(
*[
(
characters[current:next_token],
np.mean(probabilities[current:next_token]),
)
for current, next_token in indices
]
)
# Last offset is not used, padded with a 0 to match the length of text_list
return offsets + [0]
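As a worked example of the new cursor-based offsets (values mirror the NER test cases added below; assuming compute_offsets_by_level is importable from dan.ocr.predict.attention like the other helpers):

from dan.ocr.predict.attention import Level, compute_offsets_by_level

# One space separates the two entities in the full text, so the offset after the
# first entity is 1; the trailing 0 pads the list to the length of text_list.
print(compute_offsets_by_level("ⓐLoveⒶ ⓑTokyoⒷ", Level.NER, ["ⓐLoveⒶ", "ⓑTokyoⒷ"]))  # [1, 0]

# For word and line levels the offsets are constant: one separator character
# between consecutive parts, plus the trailing 0.
print(compute_offsets_by_level("Lo ve\nTokyo", Level.Word, ["Lo", "ve", "Tokyo"]))  # [1, 1, 0]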
def compute_prob_by_separator(
@@ -113,13 +64,12 @@ def compute_prob_by_separator(
Returns a list of confidence scores.
"""
# match anything except separators, get start and end index
pattern = re.compile(f"[^{separator.pattern}]+")
matches = [(m.start(), m.end()) for m in re.finditer(pattern, characters)]
matches = [(m.start(), m.end()) for m in separator.finditer(characters)]
# Iterate over text pieces and compute mean confidence
probs = [np.mean(probabilities[start:end]) for (start, end) in matches]
texts = [characters[start:end] for (start, end) in matches]
return texts, probs
return [characters[start:end] for (start, end) in matches], [
np.mean(probabilities[start:end]) for (start, end) in matches
]
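Since the separator pattern now matches the text pieces themselves, the mean confidence per piece can be read off directly; a sketch with the word-level test data (assuming both helpers are importable from dan.ocr.predict.attention):

from dan.ocr.predict.attention import compute_prob_by_separator, parse_delimiters

texts, probs = compute_prob_by_separator(
    "Lo ve\nTokyo",
    [0.1, 0.1, 0.2, 0.3, 0.3, 0.4, 0.5, 0.5, 0.5, 0.5, 0.5],
    parse_delimiters(["\n", " "]),
)
# texts -> ['Lo', 've', 'Tokyo']
# probs -> mean confidence per piece: 0.1, 0.3, 0.5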
def split_text(
@@ -127,40 +77,36 @@ def split_text(
level: Level,
word_separators: re.Pattern,
line_separators: re.Pattern,
tokens: Dict[str, EntityType],
tokens_separators: re.Pattern | None = None,
) -> Tuple[List[str], List[int]]:
"""
Split text into a list of characters, words, or lines.
:param text: Text prediction from DAN
:param level: Level to visualize from [char, word, line, ner]
:param word_separators: List of word separators
:param line_separators: List of line separators
:param tokens: NER tokens used
:param word_separators: Pattern used to find words
:param line_separators: Pattern used to find lines
:param tokens_separators: Pattern used to find NER entities
"""
indices = []
match level:
case Level.Char:
text_split = list(text)
# split into words
case Level.Word:
text_split = re.split(word_separators, text)
text_split = word_separators.split(text)
# split into lines
case Level.Line:
text_split = re.split(line_separators, text)
text_split = line_separators.split(text)
# split into entities
case Level.NER:
if not tokens:
if not tokens_separators:
logger.error("Cannot compute NER level: tokens not found")
return [], []
indices = build_ner_indices(text, tokens)
text_split = [text[current:next_token] for current, next_token in indices]
text_split = tokens_separators.findall(text)
case _:
logger.error(f"Level should be either {list(map(str, Level))}")
return [], []
return text_split, compute_offsets_by_level(level, text_split, indices)
return text_split, compute_offsets_by_level(text, level, text_split)
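A quick end-to-end sketch of the NER branch, using a start-token-only pattern (a single capturing group, so findall returns plain strings); the values mirror the "no end tokens" test case, and split_text is assumed to be importable like the other helpers:

from dan.ocr.predict.attention import Level, parse_delimiters, split_text
from dan.utils import EntityType, parse_tokens_pattern

texts, offsets = split_text(
    "ⒶLove ⒷTokyo",
    Level.NER,
    word_separators=parse_delimiters(["\n", " "]),
    line_separators=parse_delimiters(["\n"]),
    tokens_separators=parse_tokens_pattern([EntityType(start="Ⓐ"), EntityType(start="Ⓑ")]),
)
print(texts)    # ['ⒶLove ', 'ⒷTokyo']
print(offsets)  # [0, 0]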
def split_text_and_confidences(
@@ -169,19 +115,17 @@ def split_text_and_confidences(
level: Level,
word_separators: re.Pattern,
line_separators: re.Pattern,
tokens: Dict[str, EntityType],
tokens_separators: re.Pattern | None = None,
) -> Tuple[List[str], List[np.float64], List[int]]:
"""
Split text into a list of characters, words or lines with corresponding confidence scores
:param text: Text prediction from DAN
:param confidences: Character confidences
:param level: Level to visualize from [char, word, line, ner]
:param word_separators: List of word separators
:param line_separators: List of line separators
:param tokens: NER tokens used
:param word_separators: Pattern used to find words
:param line_separators: Pattern used to find lines
:param tokens_separators: Pattern used to find NER entities
"""
indices = []
match level:
case Level.Char:
texts = list(text)
@@ -194,15 +138,13 @@ def split_text_and_confidences(
text, confidences, line_separators
)
case Level.NER:
if not tokens:
if not tokens_separators:
logger.error("Cannot compute NER level: tokens not found")
return [], [], []
indices = build_ner_indices(text, tokens)
if not indices:
return [], [], []
texts, confidences = compute_prob_by_ner(text, confidences, indices)
texts, confidences = compute_prob_by_separator(
text, confidences, tokens_separators
)
case _:
logger.error(f"Level should be either {list(map(str, Level))}")
return [], [], []
@@ -210,7 +152,7 @@ def split_text_and_confidences(
return (
texts,
[np.around(num, 2) for num in confidences],
compute_offsets_by_level(level, texts, indices),
compute_offsets_by_level(text, level, texts),
)
@@ -224,7 +166,7 @@ def get_predicted_polygons_with_confidence(
max_object_height: int = 50,
word_separators: re.Pattern = parse_delimiters(["\n", " "]),
line_separators: re.Pattern = parse_delimiters(["\n"]),
tokens: Dict[str, EntityType] = {},
tokens_separators: re.Pattern | None = None,
) -> List[dict]:
"""
Returns the polygons of each object of the current prediction
@@ -235,13 +177,13 @@
:param height: Original image height
:param width: Original image width
:param max_object_height: Maximum height of predicted objects.
:param word_separators: List of word separators
:param line_separators: List of line separators
:param tokens: NER tokens used
:param word_separators: Pattern used to find words
:param line_separators: Pattern used to find lines
:param tokens_separators: Pattern used to find NER entities
"""
# Split text into characters, words or lines
text_list, confidence_list, offsets = split_text_and_confidences(
text, confidences, level, word_separators, line_separators, tokens
text, confidences, level, word_separators, line_separators, tokens_separators
)
max_value = weights.sum(0).max()
@@ -26,6 +26,7 @@ from dan.utils import (
ind_to_token,
list_to_batches,
pad_images,
parse_tokens_pattern,
read_image,
)
@@ -361,6 +362,9 @@ def process_batch(
word_separators = parse_delimiters(word_separators)
line_separators = parse_delimiters(line_separators)
# NER entity separators
ner_separators = parse_tokens_pattern(tokens.values())
# Predict
logger.info("Predicting...")
prediction = dan_model.predict(
@@ -407,7 +411,7 @@ def process_batch(
level,
word_separators,
line_separators,
tokens,
ner_separators,
)
for text, conf in zip(texts, confidences):
@@ -430,7 +434,7 @@ def process_batch(
color_map=color_map,
word_separators=word_separators,
line_separators=line_separators,
tokens=tokens,
tokens_separators=ner_separators,
display_polygons=predict_objects,
max_object_height=max_object_height,
outname=gif_filename,
# -*- coding: utf-8 -*-
import json
import re
from argparse import ArgumentTypeError
from itertools import islice
from operator import attrgetter
@@ -148,6 +149,19 @@ def parse_tokens(filename: str) -> Dict[str, EntityType]:
return tokens
def parse_tokens_pattern(tokens: list[EntityType]) -> re.Pattern[str]:
starting_tokens = "".join(token.start for token in tokens)
# Check if there are end tokens
if all(token.end for token in tokens):
# Matches a starting token, the corresponding ending token, and any text in between
return re.compile(
r"|".join(f"({token.start}.*?{token.end})" for token in tokens)
)
# Matches a starting token then any character that is not a starting token
return re.compile(rf"([{starting_tokens}][^{starting_tokens}]*)")
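The two branches produce the pattern shapes asserted in the new test below; a short illustration (EntityType and parse_tokens_pattern both live in dan.utils):

from dan.utils import EntityType, parse_tokens_pattern

# With end tokens: one alternative per entity, matching lazily between start and end.
with_ends = parse_tokens_pattern(
    [EntityType(start="ⓐ", end="Ⓐ"), EntityType(start="ⓑ", end="Ⓑ")]
)
print(with_ends.pattern)  # (ⓐ.*?Ⓐ)|(ⓑ.*?Ⓑ)

# Without end tokens: an entity runs from its start token up to the next start token.
no_ends = parse_tokens_pattern([EntityType(start="Ⓐ"), EntityType(start="Ⓑ")])
print(no_ends.pattern)                  # ([ⒶⒷ][^ⒶⒷ]*)
print(no_ends.findall("ⒶLove ⒷTokyo"))  # ['ⒶLove ', 'ⒷTokyo']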
def read_yaml(yaml_path: str) -> Dict:
"""
Read YAML tokens file.
import pytest
from dan.ocr.predict.attention import (
Level,
parse_delimiters,
split_text_and_confidences,
)
from dan.utils import EntityType, parse_tokens_pattern
@pytest.mark.parametrize(
(
"text",
"confidence",
"level",
"tokens",
"split_text",
"mean_confidences",
"expected_offsets",
),
[
# level: char
(
"Tok",
[0.1, 0.2, 0.3],
Level.Char,
None,
["T", "o", "k"],
[0.1, 0.2, 0.3],
[0, 0, 0],
),
# level: word
(
"Lo ve\nTokyo",
[0.1, 0.1, 0.2, 0.3, 0.3, 0.4, 0.5, 0.5, 0.5, 0.5, 0.5],
Level.Word,
None,
["Lo", "ve", "Tokyo"],
[0.1, 0.3, 0.5],
[1, 1, 0],
),
# level: line
(
"Love\nTokyo",
[0.1, 0.1, 0.1, 0.1, 0.2, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3],
Level.Line,
None,
["Love", "Tokyo"],
[0.1, 0.3],
[1, 0],
),
# level: NER (no end tokens)
(
"ⒶLove ⒷTokyo",
[0.1, 0.2, 0.2, 0.2, 0.2, 0.3, 0.4, 0.5, 0.5, 0.5, 0.5, 0.5],
Level.NER,
[EntityType(start="Ⓐ"), EntityType(start="Ⓑ")],
["ⒶLove ", "ⒷTokyo"],
[0.2, 0.48],
[0, 0],
),
# level: NER (with end tokens)
(
"ⓐLoveⒶ ⓑTokyoⒷ",
[0.1, 0.2, 0.2, 0.2, 0.2, 0.3, 0.4, 0.5, 0.6, 0.6, 0.6, 0.6, 0.6, 0.7],
Level.NER,
[EntityType(start="ⓐ", end="Ⓐ"), EntityType(start="ⓑ", end="Ⓑ")],
["ⓐLoveⒶ", "ⓑTokyoⒷ"],
[0.2, 0.6],
[1, 0],
),
# level: NER (no end tokens, no space)
(
"ⒶLoveⒷTokyo",
[0.1, 0.2, 0.2, 0.2, 0.2, 0.3, 0.4, 0.4, 0.4, 0.4, 0.4],
Level.NER,
[EntityType(start="Ⓐ"), EntityType(start="Ⓑ")],
["ⒶLove", "ⒷTokyo"],
[0.18, 0.38],
[0, 0],
),
# level: NER (with end tokens, no space)
(
"ⓐLoveⒶⓑTokyoⒷ",
[0.1, 0.2, 0.2, 0.2, 0.2, 0.3, 0.4, 0.5, 0.5, 0.5, 0.5, 0.5, 0.6],
Level.NER,
[EntityType(start="ⓐ", end="Ⓐ"), EntityType(start="ⓑ", end="Ⓑ")],
["ⓐLoveⒶ", "ⓑTokyoⒷ"],
[0.2, 0.5],
[0, 0],
),
],
)
def test_split_text_and_confidences(
text: str,
confidence: list[float],
level: Level,
tokens: list[EntityType] | None,
split_text: list[str],
mean_confidences: list[float],
expected_offsets: list[int],
):
texts, averages, offsets = split_text_and_confidences(
text=text,
confidences=confidence,
level=level,
word_separators=parse_delimiters([" ", "\n"]),
line_separators=parse_delimiters(["\n"]),
tokens_separators=parse_tokens_pattern(tokens) if tokens else None,
)
assert texts == split_text
assert averages == mean_confidences
assert offsets == expected_offsets
import pytest
import yaml
from dan.utils import parse_tokens
from dan.utils import EntityType, parse_tokens, parse_tokens_pattern
@pytest.mark.parametrize(
@@ -37,3 +37,26 @@ def test_parse_tokens_errors(tmp_path, tokens, error_msg):
with pytest.raises(AssertionError, match=error_msg):
parse_tokens(tokens_path)
@pytest.mark.parametrize(
("pattern", "entity_types"),
[
# No end tokens
(
r"([ⒶⒷⒸ][^ⒶⒷⒸ]*)",
[EntityType(start="Ⓐ"), EntityType(start="Ⓑ"), EntityType(start="Ⓒ")],
),
# With end tokens
(
r"(ⓐ.*?Ⓐ)|(ⓑ.*?Ⓑ)|(ⓒ.*?Ⓒ)",
[
EntityType(start="ⓐ", end="Ⓐ"),
EntityType(start="ⓑ", end="Ⓑ"),
EntityType(start="ⓒ", end="Ⓒ"),
],
),
],
)
def test_parse_tokens_pattern(pattern: str, entity_types: list[EntityType]):
assert parse_tokens_pattern(entity_types).pattern == pattern