diff --git a/dan/ocr/predict/attention.py b/dan/ocr/predict/attention.py
index b2c273eddcf916a876e4667887acab40a04273e7..978584cb4b0f4350b4c9646eeb23591c0f743e14 100644
--- a/dan/ocr/predict/attention.py
+++ b/dan/ocr/predict/attention.py
@@ -2,7 +2,6 @@
 import logging
 import re
 from enum import Enum
-from itertools import pairwise
 from typing import Dict, List, Tuple
 
 import cv2
@@ -28,78 +27,30 @@ class Level(str, Enum):
 
 
 def parse_delimiters(delimiters: List[str]) -> re.Pattern:
-    return re.compile(r"|".join(delimiters))
+    return re.compile(rf"[^{'|'.join(delimiters)}]+")
 
 
-def build_ner_indices(
-    text: str, tokens: Dict[str, EntityType]
-) -> List[Tuple[int, int]]:
-    """
-    Compute the position of NER tokens in the text and return a list of indices.
-    :param text: list of characters.
-    :param tokens: NER tokens used.
-    Returns a list of indices where tokens are located.
-    """
-    start_tokens, end_tokens = zip(*list(tokens.values()))
-    end_tokens = list(filter(bool, end_tokens))
-
-    if len(end_tokens):
-        assert len(start_tokens) == len(
-            end_tokens
-        ), "You don't have the same number of starting tokens and ending tokens"
-        return [
-            [pos_start, pos_end] for pos_start, pos_end in zip(start_tokens, end_tokens)
-        ]
-
-    return list(
-        pairwise(
-            [pos for pos, char in enumerate(text) if char in start_tokens] + [None]
-        )
-    )
-
-
-def compute_offsets_by_level(
-    level: Level, text_list: List[str], indices: List[Tuple[int, int]]
-):
+def compute_offsets_by_level(full_text: str, level: Level, text_list: List[str]):
     """
     Compute and return the list of offset between each text part.
+    :param full_text: predicted text.
     :param level: Level to use from [char, word, line, ner].
     :param text_list: list of text to use.
-    :param indices: list of indices where tokens are located for NER computation.
     Returns a list of offsets.
     """
-    if level == Level.NER:
-        return (
-            [
-                current - next_token
-                for (_, next_token), (current, _) in pairwise(indices)
-            ]
-            # Pad the list to match the length of the text list
-            + [0]
-        )
-
-    return [int(level != Level.Char)] * len(text_list)
+    # offsets[idx] = number of characters between text_list[idx-1] and text_list[idx]
+    offsets = [int(level != Level.Char)] * (len(text_list) - 1)
+    if level == Level.NER:
+        # Start after the first entity
+        cursor = len(text_list[0])
+        for idx, split in enumerate(text_list[1:]):
+            # Number of characters between this entity and the previous one
+            offsets[idx] = full_text[cursor:].index(split)
+            cursor += offsets[idx] + len(split)
 
 
-def compute_prob_by_ner(
-    characters: str, probabilities: List[float], indices: List[Tuple[int, int]]
-) -> Tuple[List[str], List[np.float64]]:
-    """
-    Split text and confidences using indices and return a list of average confidence scores.
-    :param characters: list of characters.
-    :param probabilities: list of character probabilities.
-    :param indices: list of indices where tokens are located.
-    Returns a list confidence scores.
-    """
-    return zip(
-        *[
-            (
-                characters[current:next_token],
-                np.mean(probabilities[current:next_token]),
-            )
-            for current, next_token in indices
-        ]
-    )
+    # Last offset is not used, padded with a 0 to match the length of text_list
+    return offsets + [0]
 
 
 def compute_prob_by_separator(
@@ -113,13 +64,12 @@ def compute_prob_by_separator(
     Returns a list confidence scores.
""" # match anything except separators, get start and end index - pattern = re.compile(f"[^{separator.pattern}]+") - matches = [(m.start(), m.end()) for m in re.finditer(pattern, characters)] + matches = [(m.start(), m.end()) for m in separator.finditer(characters)] # Iterate over text pieces and compute mean confidence - probs = [np.mean(probabilities[start:end]) for (start, end) in matches] - texts = [characters[start:end] for (start, end) in matches] - return texts, probs + return [characters[start:end] for (start, end) in matches], [ + np.mean(probabilities[start:end]) for (start, end) in matches + ] def split_text( @@ -127,40 +77,36 @@ def split_text( level: Level, word_separators: re.Pattern, line_separators: re.Pattern, - tokens: Dict[str, EntityType], + tokens_separators: re.Pattern | None = None, ) -> Tuple[List[str], List[int]]: """ Split text into a list of characters, word, or lines. :param text: Text prediction from DAN :param level: Level to visualize from [char, word, line, ner] - :param word_separators: List of word separators - :param line_separators: List of line separators - :param tokens: NER tokens used + :param word_separators: Pattern used to find words + :param line_separators: Pattern used to find lines + :param tokens_separators: Pattern used to find NER entities """ - indices = [] - match level: case Level.Char: text_split = list(text) # split into words case Level.Word: - text_split = re.split(word_separators, text) + text_split = word_separators.split(text) # split into lines case Level.Line: - text_split = re.split(line_separators, text) + text_split = line_separators.split(text) # split into entities case Level.NER: - if not tokens: + if not tokens_separators: logger.error("Cannot compute NER level: tokens not found") return [], [] - - indices = build_ner_indices(text, tokens) - text_split = [text[current:next_token] for current, next_token in indices] + text_split = tokens_separators.findall(text) case _: logger.error(f"Level should be either {list(map(str, Level))}") return [], [] - return text_split, compute_offsets_by_level(level, text_split, indices) + return text_split, compute_offsets_by_level(text, level, text_split) def split_text_and_confidences( @@ -169,19 +115,17 @@ def split_text_and_confidences( level: Level, word_separators: re.Pattern, line_separators: re.Pattern, - tokens: Dict[str, EntityType], + tokens_separators: re.Pattern | None = None, ) -> Tuple[List[str], List[np.float64], List[int]]: """ Split text into a list of characters, words or lines with corresponding confidences scores :param text: Text prediction from DAN :param confidences: Character confidences :param level: Level to visualize from [char, word, line, ner] - :param word_separators: List of word separators - :param line_separators: List of line separators - :param tokens: NER tokens used + :param word_separators: Pattern used to find words + :param line_separators: Pattern used to find lines + :param tokens_separators: Pattern used to find NER entities """ - indices = [] - match level: case Level.Char: texts = list(text) @@ -194,15 +138,13 @@ def split_text_and_confidences( text, confidences, line_separators ) case Level.NER: - if not tokens: + if not tokens_separators: logger.error("Cannot compute NER level: tokens not found") return [], [], [] - indices = build_ner_indices(text, tokens) - if not indices: - return [], [], [] - - texts, confidences = compute_prob_by_ner(text, confidences, indices) + texts, confidences = compute_prob_by_separator( + text, confidences, 
tokens_separators + ) case _: logger.error(f"Level should be either {list(map(str, Level))}") return [], [], [] @@ -210,7 +152,7 @@ def split_text_and_confidences( return ( texts, [np.around(num, 2) for num in confidences], - compute_offsets_by_level(level, texts, indices), + compute_offsets_by_level(text, level, texts), ) @@ -224,7 +166,7 @@ def get_predicted_polygons_with_confidence( max_object_height: int = 50, word_separators: re.Pattern = parse_delimiters(["\n", " "]), line_separators: re.Pattern = parse_delimiters(["\n"]), - tokens: Dict[str, EntityType] = {}, + tokens_separators: re.Pattern | None = None, ) -> List[dict]: """ Returns the polygons of each object of the current prediction @@ -235,13 +177,13 @@ def get_predicted_polygons_with_confidence( :param height: Original image height :param width: Original image width :param max_object_height: Maximum height of predicted objects. - :param word_separators: List of word separators - :param line_separators: List of line separators - :param tokens: NER tokens used + :param word_separators: Pattern used to find words + :param line_separators: Pattern used to find lines + :param tokens_separators: Pattern used to find NER entities """ # Split text into characters, words or lines text_list, confidence_list, offsets = split_text_and_confidences( - text, confidences, level, word_separators, line_separators, tokens + text, confidences, level, word_separators, line_separators, tokens_separators ) max_value = weights.sum(0).max() diff --git a/dan/ocr/predict/inference.py b/dan/ocr/predict/inference.py index d0357a7067478239d5125fa4ffb7e2b1d6c72dad..ae23a3f99c7db2093d73957bb39f9e543e6a2719 100644 --- a/dan/ocr/predict/inference.py +++ b/dan/ocr/predict/inference.py @@ -26,6 +26,7 @@ from dan.utils import ( ind_to_token, list_to_batches, pad_images, + parse_tokens_pattern, read_image, ) @@ -361,6 +362,9 @@ def process_batch( word_separators = parse_delimiters(word_separators) line_separators = parse_delimiters(line_separators) + # NER Entities separators + ner_separators = parse_tokens_pattern(tokens.values()) + # Predict logger.info("Predicting...") prediction = dan_model.predict( @@ -407,7 +411,7 @@ def process_batch( level, word_separators, line_separators, - tokens, + ner_separators, ) for text, conf in zip(texts, confidences): @@ -430,7 +434,7 @@ def process_batch( color_map=color_map, word_separators=word_separators, line_separators=line_separators, - tokens=tokens, + tokens_separators=ner_separators, display_polygons=predict_objects, max_object_height=max_object_height, outname=gif_filename, diff --git a/dan/utils.py b/dan/utils.py index 3c5cfe98c7070d3fc1d5c052588b6e5e701091fa..b584c60841853f5034192bed4f360b1d8c54fdf7 100644 --- a/dan/utils.py +++ b/dan/utils.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- import json +import re from argparse import ArgumentTypeError from itertools import islice from operator import attrgetter @@ -148,6 +149,19 @@ def parse_tokens(filename: str) -> Dict[str, EntityType]: return tokens +def parse_tokens_pattern(tokens: list[EntityType]) -> re.Pattern[str]: + starting_tokens = "".join(token.start for token in tokens) + # Check if there are end tokens + if all(token.end for token in tokens): + # Matches a starting token, the corresponding ending token, and any text in between + return re.compile( + r"|".join(f"({token.start}.*?{token.end})" for token in tokens) + ) + + # Matches a starting token then any character that is not a starting token + return re.compile(rf"([{starting_tokens}][^{starting_tokens}]*)") 
+
+
 def read_yaml(yaml_path: str) -> Dict:
     """
     Read YAML tokens file.
     """
diff --git a/tests/test_attention.py b/tests/test_attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..a57cf2796a4dc0a237f4485edfb9abab962c7554
--- /dev/null
+++ b/tests/test_attention.py
@@ -0,0 +1,114 @@
+import pytest
+
+from dan.ocr.predict.attention import (
+    Level,
+    parse_delimiters,
+    split_text_and_confidences,
+)
+from dan.utils import EntityType, parse_tokens_pattern
+
+
+@pytest.mark.parametrize(
+    (
+        "text",
+        "confidence",
+        "level",
+        "tokens",
+        "split_text",
+        "mean_confidences",
+        "expected_offsets",
+    ),
+    [
+        # level: char
+        (
+            "Tok",
+            [0.1, 0.2, 0.3],
+            Level.Char,
+            None,
+            ["T", "o", "k"],
+            [0.1, 0.2, 0.3],
+            [0, 0, 0],
+        ),
+        # level: word
+        (
+            "Lo ve\nTokyo",
+            [0.1, 0.1, 0.2, 0.3, 0.3, 0.4, 0.5, 0.5, 0.5, 0.5, 0.5],
+            Level.Word,
+            None,
+            ["Lo", "ve", "Tokyo"],
+            [0.1, 0.3, 0.5],
+            [1, 1, 0],
+        ),
+        # level: line
+        (
+            "Love\nTokyo",
+            [0.1, 0.1, 0.1, 0.1, 0.2, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3],
+            Level.Line,
+            None,
+            ["Love", "Tokyo"],
+            [0.1, 0.3],
+            [1, 0],
+        ),
+        # level: NER (no end tokens)
+        (
+            "ⒶLove ⒷTokyo",
+            [0.1, 0.2, 0.2, 0.2, 0.2, 0.3, 0.4, 0.5, 0.5, 0.5, 0.5, 0.5],
+            Level.NER,
+            [EntityType(start="Ⓐ"), EntityType(start="Ⓑ")],
+            ["ⒶLove ", "ⒷTokyo"],
+            [0.2, 0.48],
+            [0, 0],
+        ),
+        # level: NER (with end tokens)
+        (
+            "ⓐLoveⒶ ⓑTokyoⒷ",
+            [0.1, 0.2, 0.2, 0.2, 0.2, 0.3, 0.4, 0.5, 0.6, 0.6, 0.6, 0.6, 0.6, 0.7],
+            Level.NER,
+            [EntityType(start="ⓐ", end="Ⓐ"), EntityType(start="ⓑ", end="Ⓑ")],
+            ["ⓐLoveⒶ", "ⓑTokyoⒷ"],
+            [0.2, 0.6],
+            [1, 0],
+        ),
+        # level: NER (no end tokens, no space)
+        (
+            "ⒶLoveⒷTokyo",
+            [0.1, 0.2, 0.2, 0.2, 0.2, 0.3, 0.4, 0.4, 0.4, 0.4, 0.4],
+            Level.NER,
+            [EntityType(start="Ⓐ"), EntityType(start="Ⓑ")],
+            ["ⒶLove", "ⒷTokyo"],
+            [0.18, 0.38],
+            [0, 0],
+        ),
+        # level: NER (with end tokens, no space)
+        (
+            "ⓐLoveⒶⓑTokyoⒷ",
+            [0.1, 0.2, 0.2, 0.2, 0.2, 0.3, 0.4, 0.5, 0.5, 0.5, 0.5, 0.5, 0.6],
+            Level.NER,
+            [EntityType(start="ⓐ", end="Ⓐ"), EntityType(start="ⓑ", end="Ⓑ")],
+            ["ⓐLoveⒶ", "ⓑTokyoⒷ"],
+            [0.2, 0.5],
+            [0, 0],
+        ),
+    ],
+)
+def test_split_text_and_confidences(
+    text: str,
+    confidence: list[float],
+    level: Level,
+    tokens: list[EntityType] | None,
+    split_text: list[str],
+    mean_confidences: list[list[float]],
+    expected_offsets: list[int],
+):
+    texts, averages, offsets = split_text_and_confidences(
+        text=text,
+        confidences=confidence,
+        level=level,
+        word_separators=parse_delimiters([" ", "\n"]),
+        line_separators=parse_delimiters(["\n"]),
+        tokens_separators=parse_tokens_pattern(tokens) if tokens else None,
+    )
+
+    assert texts == split_text
+    assert averages == mean_confidences
+    assert offsets == expected_offsets
diff --git a/tests/test_utils.py b/tests/test_utils.py
index ecaf3bd74a8cd7c2f5b779518740e3953ae7d355..fcb3ecf3dc8003d943024c9afde1941ea34afff9 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -1,7 +1,7 @@
 import pytest
 import yaml
 
-from dan.utils import parse_tokens
+from dan.utils import EntityType, parse_tokens, parse_tokens_pattern
 
 
 @pytest.mark.parametrize(
@@ -37,3 +37,26 @@ def test_parse_tokens_errors(tmp_path, tokens, error_msg):
 
     with pytest.raises(AssertionError, match=error_msg):
         parse_tokens(tokens_path)
+
+
+@pytest.mark.parametrize(
+    ("pattern", "entity_types"),
+    [
+        # No end tokens
+        (
+            r"([ⒶⒷⒸ][^ⒶⒷⒸ]*)",
+            [EntityType(start="Ⓐ"), EntityType(start="Ⓑ"), EntityType(start="Ⓒ")],
+        ),
+        # With end tokens
+        (
+            r"(ⓐ.*?Ⓐ)|(ⓑ.*?Ⓑ)|(ⓒ.*?Ⓒ)",
+            [
+                EntityType(start="ⓐ", end="Ⓐ"),
+                EntityType(start="ⓑ", end="Ⓑ"),
+                EntityType(start="ⓒ", end="Ⓒ"),
+            ],
+        ),
+    ],
+)
+def test_parse_tokens_pattern(pattern: str, entity_types: list[EntityType]):
+    assert parse_tokens_pattern(entity_types).pattern == pattern
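Usage sketch (illustrative, not part of the patch): how the pattern-based NER splitting introduced in this diff fits together. The entity tokens Ⓐ/Ⓑ, the sample text and the expected values are taken from the new test cases above; the variable name ner_separators mirrors the one used in process_batch.

from dan.ocr.predict.attention import Level, parse_delimiters, split_text_and_confidences
from dan.utils import EntityType, parse_tokens_pattern

# Build the NER pattern once (start-only tokens here), then reuse it for splitting
ner_separators = parse_tokens_pattern([EntityType(start="Ⓐ"), EntityType(start="Ⓑ")])

texts, scores, offsets = split_text_and_confidences(
    text="ⒶLove ⒷTokyo",
    confidences=[0.1, 0.2, 0.2, 0.2, 0.2, 0.3, 0.4, 0.5, 0.5, 0.5, 0.5, 0.5],
    level=Level.NER,
    word_separators=parse_delimiters([" ", "\n"]),
    line_separators=parse_delimiters(["\n"]),
    tokens_separators=ner_separators,
)
# texts   == ["ⒶLove ", "ⒷTokyo"]
# scores  == [0.2, 0.48]   (per-entity mean confidences, rounded to 2 decimals)
# offsets == [0, 0]        (characters between consecutive entities, last entry padded with 0)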