Commit dab1d6cd authored by Manon Blanco

Merge branch 'split-on-charset' into 'main'

Split on charset

Closes #219

See merge request !393
parents 04326cc2 3c9f74f6
@@ -75,6 +75,7 @@ def compute_prob_by_separator(
 def split_text(
     text: str,
     level: Level,
+    char_separators: re.Pattern,
     word_separators: re.Pattern,
     line_separators: re.Pattern,
     tokens_separators: re.Pattern | None = None,
@@ -83,13 +84,14 @@
     Split text into a list of characters, words, or lines.
     :param text: Text prediction from DAN
     :param level: Level to visualize from [char, word, line, ner]
+    :param char_separators: Pattern used to find tokens of the charset
     :param word_separators: Pattern used to find words
     :param line_separators: Pattern used to find lines
     :param tokens_separators: Pattern used to find NER entities
     """
     match level:
         case Level.Char:
-            text_split = list(text)
+            text_split = char_separators.findall(text)
         # split into words
         case Level.Word:
             text_split = word_separators.split(text)
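To illustrate the change at char level, a minimal sketch (the charset and expected output are taken from the updated test further down; the hand-written pattern mirrors what parse_charset_pattern builds for it):

    import re

    # Hypothetical charset: single characters plus the special token "<kyo>"
    char_separators = re.compile(r"(?:<kyo>)|[To ]")

    print(list("To <kyo>"))                     # ['T', 'o', ' ', '<', 'k', 'y', 'o', '>']
    print(char_separators.findall("To <kyo>"))  # ['T', 'o', ' ', '<kyo>']

Special tokens such as "<kyo>" now survive the char-level split as single units instead of being exploded into individual characters.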
@@ -113,6 +115,7 @@ def split_text_and_confidences(
     text: str,
     confidences: List[float],
     level: Level,
+    char_separators: re.Pattern,
     word_separators: re.Pattern,
     line_separators: re.Pattern,
     tokens_separators: re.Pattern | None = None,
@@ -122,13 +125,14 @@
     :param text: Text prediction from DAN
     :param confidences: Character confidences
     :param level: Level to visualize from [char, word, line, ner]
+    :param char_separators: Pattern used to find tokens of the charset
     :param word_separators: Pattern used to find words
     :param line_separators: Pattern used to find lines
     :param tokens_separators: Pattern used to find NER entities
     """
     match level:
         case Level.Char:
-            texts = list(text)
+            texts = char_separators.findall(text)
         case Level.Word:
             texts, confidences = compute_prob_by_separator(
                 text, confidences, word_separators
@@ -163,6 +167,7 @@ def get_predicted_polygons_with_confidence(
     level: Level,
     height: int,
     width: int,
+    char_separators: re.Pattern,
     max_object_height: int = 50,
     word_separators: re.Pattern = parse_delimiters(["\n", " "]),
     line_separators: re.Pattern = parse_delimiters(["\n"]),
@@ -176,6 +181,7 @@
     :param level: Level to display (must be in [char, word, line, ner])
     :param height: Original image height
     :param width: Original image width
+    :param char_separators: Pattern used to find tokens of the charset
     :param max_object_height: Maximum height of predicted objects.
     :param word_separators: Pattern used to find words
     :param line_separators: Pattern used to find lines
@@ -183,7 +189,13 @@
     """
     # Split text into characters, words or lines
     text_list, confidence_list, offsets = split_text_and_confidences(
-        text, confidences, level, word_separators, line_separators, tokens_separators
+        text,
+        confidences,
+        level,
+        char_separators,
+        word_separators,
+        line_separators,
+        tokens_separators,
     )
     max_value = weights.sum(0).max()
@@ -26,6 +26,7 @@ from dan.utils import (
     ind_to_token,
     list_to_batches,
     pad_images,
+    parse_charset_pattern,
     parse_tokens_pattern,
     read_image,
 )
@@ -359,6 +360,7 @@ def process_batch(
     logger.info("Images preprocessed!")

     # Parse delimiters to regex
+    char_separators = parse_charset_pattern(dan_model.charset)
     word_separators = parse_delimiters(word_separators)
     line_separators = parse_delimiters(line_separators)
@@ -409,6 +411,7 @@ def process_batch(
             predicted_text,
             char_confidences,
             level,
+            char_separators,
             word_separators,
             line_separators,
             ner_separators,
@@ -2,7 +2,7 @@
 import json
 import re
 from argparse import ArgumentTypeError
-from itertools import islice
+from itertools import islice, takewhile
 from operator import attrgetter
 from pathlib import Path
 from typing import Dict, NamedTuple
@@ -162,6 +162,33 @@ def parse_tokens_pattern(tokens: list[EntityType]) -> re.Pattern[str]:
     return re.compile(rf"([{starting_tokens}][^{starting_tokens}]*)")


+def parse_charset_pattern(charset: list[str]) -> re.Pattern[str]:
+    """
+    Build a pattern matching any token of the charset: tokens longer than one
+    character are wrapped in non-capturing groups (?:...), single characters
+    go into a character class [...]. Longer tokens are matched first.
+    """
+    tokens = sorted(charset, key=len)
+
+    # Tokens of length 1
+    letters = list(takewhile(lambda t: len(t) == 1, tokens))
+
+    pattern = ""
+
+    if words := tokens[len(letters) :]:
+        # Tokens longer than 1 character, longest first: (?:...)|(?:...)|...
+        # Non-capturing groups make re.findall return the full match
+        # instead of the group contents
+        pattern += (
+            r"|".join(f"(?:{re.escape(token)})" for token in reversed(words)) + "|"
+        )
+
+    # [...] for 1-character tokens
+    pattern += "[" + r"".join(map(re.escape, letters)) + "]"
+
+    return re.compile(pattern)
+
+
 def read_yaml(yaml_path: str) -> Dict:
     """
     Read YAML tokens file.
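For reference, a minimal sketch of what this helper produces, using the charset from the new test at the bottom of this diff (the import path is the one used by the test files):

    from dan.utils import parse_charset_pattern

    pattern = parse_charset_pattern(["<language>", "[", "<word>", "]", "\\", '"'])
    print(pattern.pattern)    # (?:<language>)|(?:<word>)|[\[\]\\"]
    print(pattern.findall('<word>["<language>"]'))
    # ['<word>', '[', '"', '<language>', '"', ']']

Sorting by length and then reversing puts longer tokens first in the alternation; since regex alternation takes the first branch that matches, a long token such as "<language>" can never be swallowed piecemeal by shorter entries.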
@@ -5,7 +5,7 @@ from dan.ocr.predict.attention import (
     parse_delimiters,
     split_text_and_confidences,
 )
-from dan.utils import EntityType, parse_tokens_pattern
+from dan.utils import EntityType, parse_charset_pattern, parse_tokens_pattern


 @pytest.mark.parametrize(
@@ -21,13 +21,13 @@ from dan.utils import EntityType, parse_tokens_pattern
     [
         # level: char
         (
-            "Tok",
-            [0.1, 0.2, 0.3],
+            "To <kyo>",
+            [0.1, 0.2, 0.3, 0.4],
             Level.Char,
             None,
-            ["T", "o", "k"],
-            [0.1, 0.2, 0.3],
-            [0, 0, 0],
+            ["T", "o", " ", "<kyo>"],
+            [0.1, 0.2, 0.3, 0.4],
+            [0, 0, 0, 0],
         ),
         # level: word
         (
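Worth noting why this test case changed: DAN emits one confidence per predicted charset token, not per Unicode character. "To <kyo>" is four tokens under the new charset ("T", "o", " ", "<kyo>"), so it carries four confidences, and the char-level split now yields four pieces that stay aligned with them; list(text) would have produced eight.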
@@ -100,10 +100,32 @@ def test_split_text_and_confidences(
     mean_confidences: list[list[float]],
     expected_offsets: list[int],
 ):
+    # Full charset
+    charset = [
+        # alphabet
+        "T",
+        "o",
+        "L",
+        "v",
+        "e",
+        "k",
+        "y",
+        # Entities
+        "",
+        "",
+        "",
+        "",
+        # Special
+        "<kyo>",
+        # Punctuation
+        " ",
+    ]
+
     texts, averages, offsets = split_text_and_confidences(
         text=text,
         confidences=confidence,
         level=level,
+        char_separators=parse_charset_pattern(charset),
         word_separators=parse_delimiters([" ", "\n"]),
        line_separators=parse_delimiters(["\n"]),
         tokens_separators=parse_tokens_pattern(tokens) if tokens else None,
 import pytest
 import yaml

-from dan.utils import EntityType, parse_tokens, parse_tokens_pattern
+from dan.utils import (
+    EntityType,
+    parse_charset_pattern,
+    parse_tokens,
+    parse_tokens_pattern,
+)


 @pytest.mark.parametrize(
@@ -60,3 +65,18 @@ def test_parse_tokens_errors(tmp_path, tokens, error_msg):
 )
 def test_parse_tokens_pattern(pattern: str, entity_types: list[EntityType]):
     assert parse_tokens_pattern(entity_types).pattern == pattern
+
+
+@pytest.mark.parametrize(
+    ("charset", "pattern"),
+    [
+        (["a", "b", "c"], r"[abc]"),
+        (["^", "a", "b", "c"], r"[\^abc]"),
+        (
+            ["<language>", "[", "<word>", "]", "\\", '"'],
+            r'(?:<language>)|(?:<word>)|[\[\]\\"]',
+        ),
+    ],
+)
+def test_parse_charset_pattern(charset, pattern):
+    assert parse_charset_pattern(charset).pattern == pattern
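One caveat worth keeping in mind, as a plain consequence of re.findall semantics rather than anything stated in this diff: the compiled pattern only matches members of the charset, so any character absent from it is silently dropped from the char-level split. A quick illustration with a hypothetical two-character charset:

    import re

    print(re.compile(r"[ab]").findall("a?b"))  # ['a', 'b'] ('?' is not matched, so it disappears)

The old list(text) behaviour kept every character; with findall, out-of-charset characters vanish instead.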