From 96327069ef6cdfdf90644a811591a2b79b8487dd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sol=C3=A8ne=20Tarride?= <starride@teklia.com> Date: Thu, 23 Feb 2023 13:51:05 +0100 Subject: [PATCH] add arguments for word and line separators --- dan/predict/__init__.py | 19 +++++++++++++++++-- dan/predict/attention.py | 26 ++++++++++++++++++++------ dan/predict/prediction.py | 12 ++++++++++-- 3 files changed, 47 insertions(+), 10 deletions(-) diff --git a/dan/predict/__init__.py b/dan/predict/__init__.py index 4eb9f909..1528d35d 100644 --- a/dan/predict/__init__.py +++ b/dan/predict/__init__.py @@ -63,7 +63,7 @@ def add_predict_parser(subcommands) -> None: ) parser.add_argument( "--confidence-score-levels", - default=[], + default="", type=str, nargs="+", help="Levels of confidence scores. Should be a list of any combinaison of ['char', 'word', 'line'].", @@ -90,5 +90,20 @@ def add_predict_parser(subcommands) -> None: help="Image scaling factor before creating the GIF", required=False, ) - + parser.add_argument( + "--word-separators", + default=[" ", "\n"], + type=str, + nargs="+", + help="String separators used to split text into words.", + required=False, + ) + parser.add_argument( + "--line-separators", + default=["\n"], + type=str, + nargs="+", + help="String separators used to split text into lines.", + required=False, + ) parser.set_defaults(func=run) diff --git a/dan/predict/attention.py b/dan/predict/attention.py index bdfe57a1..aab70dff 100644 --- a/dan/predict/attention.py +++ b/dan/predict/attention.py @@ -6,7 +6,7 @@ from PIL import Image from dan import logger -def split_text(text, level): +def split_text(text, level, word_separators, line_separators): """ Split text into a list of characters, word, or lines. :param text: Text prediction from DAN @@ -18,19 +18,33 @@ def split_text(text, level): offset = 0 # split into words elif level == "word": - text = text.replace("\n", " ") - text_split = text.split(" ") + main_sep = word_separators[0] + for other_sep in word_separators[1:]: + text = text.replace(other_sep, main_sep) + text_split = text.split(main_sep) offset = 1 # split into lines elif level == "line": - text_split = text.split("\n") + main_sep = line_separators[0] + for other_sep in line_separators[1:]: + text = text.replace(other_sep, main_sep) + text_split = text.split(main_sep) offset = 1 else: logger.error("Level should be either 'char', 'word', or 'line'") return text_split, offset -def plot_attention(image, text, weights, level, scale, outname): +def plot_attention( + image, + text, + weights, + level, + scale, + outname, + word_separators=["\n", " "], + line_separators=["\n"], +): """ Create a gif by blending attention maps to the image for each text piece (char, word or line) :param image: Input image in PIL format @@ -48,7 +62,7 @@ def plot_attention(image, text, weights, level, scale, outname): image = Image.fromarray(image) # Split text into characters, words or lines - text_list, offset = split_text(text, level) + text_list, offset = split_text(text, level, word_separators, line_separators) # Iterate on characters, words or lines tot_len = 0 diff --git a/dan/predict/prediction.py b/dan/predict/prediction.py index a43ed991..69ce4029 100644 --- a/dan/predict/prediction.py +++ b/dan/predict/prediction.py @@ -238,6 +238,8 @@ def run( attention_map, attention_map_level, attention_map_scale, + word_separators, + line_separators, ): # Create output directory if necessary if not os.path.exists(output): @@ -274,10 +276,14 @@ def run( char_confidences = prediction["confidences"][0] result["confidences"] = {"total": np.around(np.mean(char_confidences), 2)} if "word" in confidence_score_levels: - word_probs = compute_prob_by_separator(text, char_confidences, ["\n", " "]) + word_probs = compute_prob_by_separator( + text, char_confidences, word_separators + ) result["confidences"].update({"word": round_floats(word_probs)}) if "line" in confidence_score_levels: - line_probs = compute_prob_by_separator(text, char_confidences, ["\n"]) + line_probs = compute_prob_by_separator( + text, char_confidences, line_separators + ) result["confidences"].update({"line": round_floats(line_probs)}) if "char" in confidence_score_levels: result["confidences"].update({"char": round_floats(char_confidences)}) @@ -292,6 +298,8 @@ def run( weights=prediction["attentions"][0], level=attention_map_level, scale=attention_map_scale, + word_separators=word_separators, + line_separators=line_separators, outname=gif_filename, ) result["attention_gif"] = gif_filename -- GitLab