Skip to content
Snippets Groups Projects
Commit 828a5db1 authored by Solene Tarride's avatar Solene Tarride
Browse files

compute word/line confidence scores

parent 68434177
No related branches found
No related tags found
1 merge request: !66 "Compute confidence scores by char, word or line"
This commit is part of merge request !66. Comments created here will be created in the context of that merge request.
......@@ -61,6 +61,12 @@ def add_predict_parser(subcommands) -> None:
help="Whether to return confidence scores.",
required=False,
)
parser.add_argument(
"--confidence-score-levels",
default=[],
help="Levels of confidence scores. Should be a list of any combinaison of ['char', 'word', 'line'].",
required=False,
)
parser.add_argument(
"--attention-map",
action="store_true",
......
......@@ -17,6 +17,30 @@ from dan.predict.attention import plot_attention
from dan.utils import read_image
def compute_prob_by_separator(characters, probabilities, separators=("\n",)):
    r"""
    Split text and confidences using separators and return a list of average confidence scores.

    Characters and probabilities are consumed in lockstep; a segment ends each
    time a separator is met. Separators themselves do not contribute to any
    average, and leading, trailing or consecutive separators never produce an
    empty segment.

    :param characters: list of characters (or a string).
    :param probabilities: list of probabilities, one per character.
    :param separators: collection of characters to split text. Use ["\n", " "] for word confidences and ["\n"] for line confidences.
    Returns a list of confidence scores, one per non-empty segment.
    """
    probs = []
    # Probabilities collected for the segment currently being read.
    segment = []
    for char, prob in zip(characters, probabilities):
        if char not in separators:
            segment.append(prob)
        elif segment:
            # A separator closes the segment: emit its mean and start afresh.
            probs.append(sum(segment) / len(segment))
            segment = []
    # Flush the last segment when the text does not end with a separator.
    if segment:
        probs.append(sum(segment) / len(segment))
    return probs
class DAN:
"""
The DAN class is used to apply a DAN model.
......@@ -85,7 +109,7 @@ class DAN:
input_image = (input_image - self.mean) / self.std
return input_image
def predict(self, input_tensor, input_sizes, confidences=False, attentions=False):
def predict(self, input_tensor, input_sizes, confidences=False, attentions=False, confidences_sep=None):
"""
Run prediction on an input image.
:param input_tensor: A batch of images to predict.
......@@ -164,12 +188,14 @@ class DAN:
if torch.all(reached_end):
break
# Concatenate tensors for each token
confidence_scores = (
torch.cat(confidence_scores, dim=1).cpu().detach().numpy()
)
attention_maps = torch.cat(attention_maps, dim=1).cpu().detach().numpy()
# Remove bot and eot tokens
predicted_tokens = predicted_tokens[:, 1:]
prediction_len[torch.eq(reached_end, False)] = self.max_chars - 1
predicted_tokens = [
......@@ -178,9 +204,12 @@ class DAN:
confidence_scores = [
confidence_scores[i, : prediction_len[i]].tolist() for i in range(b)
]
# Transform tokens to characters
predicted_text = [
LM_ind_to_str(self.charset, t, oov_symbol="") for t in predicted_tokens
]
logger.info("Images processed")
out = {"text": predicted_text}
......@@ -199,6 +228,7 @@ def run(
output,
scale,
confidence_score,
confidence_score_levels,
attention_map,
attention_map_level,
attention_map_scale,
......@@ -230,13 +260,23 @@ def run(
confidences=confidence_score,
attentions=attention_map,
)
result = {"text": prediction["text"][0]}
text = prediction["text"][0]
result = {"text": text}
# Average character-based confidence scores
if confidence_score:
# TODO: select the level for confidence scores (char, word, line, total)
result["confidence"] = np.around(np.mean(prediction["confidences"][0]), 2)
char_confidences = prediction["confidences"][0]
result["confidences"] = {"total": np.around(np.mean(char_confidences), 2)}
if "word" in confidence_score_levels:
word_probs = compute_prob_by_separator(text, char_confidences, ["\n", " "])
result["confidences"].update({"word": [np.around(c, 2) for c in word_probs]})
if "line" in confidence_score_levels:
line_probs = compute_prob_by_separator(text, char_confidences, ["\n"])
result["confidences"].update({"line": [np.around(c, 2) for c in line_probs]})
if "char" in confidence_score_levels:
result["confidences"].update({"char": [np.around(c, 2) for c in char_confidences]})
# Save gif with attention map
if attention_map:
gif_filename = f"{output}/{image.stem}_{attention_map_level}.gif"
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment