Skip to content
Snippets Groups Projects
Commit 8ced0010 authored by Solene Tarride's avatar Solene Tarride
Browse files

fix lint

parent 828a5db1
No related branches found
Tags 1.5.2-rc5
1 merge request!66Compute confidence scores by char, word or line
This commit is part of merge request !66. Comments created here will be created in the context of that merge request.
......@@ -109,7 +109,14 @@ class DAN:
input_image = (input_image - self.mean) / self.std
return input_image
def predict(self, input_tensor, input_sizes, confidences=False, attentions=False, confidences_sep=None):
def predict(
self,
input_tensor,
input_sizes,
confidences=False,
attentions=False,
confidences_sep=None,
):
"""
Run prediction on an input image.
:param input_tensor: A batch of images to predict.
......@@ -188,7 +195,7 @@ class DAN:
if torch.all(reached_end):
break
# Concatenate tensors for each token
confidence_scores = (
torch.cat(confidence_scores, dim=1).cpu().detach().numpy()
......@@ -270,13 +277,19 @@ def run(
if "word" in confidence_score_levels:
word_probs = compute_prob_by_separator(text, char_confidences, ["\n", " "])
result["confidences"].update({"word": [np.around(c, 2) for c in word_probs]})
result["confidences"].update(
{"word": [np.around(c, 2) for c in word_probs]}
)
if "line" in confidence_score_levels:
line_probs = compute_prob_by_separator(text, char_confidences, ["\n"])
result["confidences"].update({"line": [np.around(c, 2) for c in line_probs]})
result["confidences"].update(
{"line": [np.around(c, 2) for c in line_probs]}
)
if "char" in confidence_score_levels:
result["confidences"].update({"char": [np.around(c, 2) for c in char_confidences]})
result["confidences"].update(
{"char": [np.around(c, 2) for c in char_confidences]}
)
# Save gif with attention map
if attention_map:
gif_filename = f"{output}/{image.stem}_{attention_map_level}.gif"
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment