Skip to content
Snippets Groups Projects
Commit 8ced0010 authored by Solene Tarride's avatar Solene Tarride
Browse files

fix lint

parent 828a5db1
No related branches found
No related tags found
1 merge request!66Compute confidence scores by char, word or line
......@@ -109,7 +109,14 @@ class DAN:
input_image = (input_image - self.mean) / self.std
return input_image
def predict(self, input_tensor, input_sizes, confidences=False, attentions=False, confidences_sep=None):
def predict(
self,
input_tensor,
input_sizes,
confidences=False,
attentions=False,
confidences_sep=None,
):
"""
Run prediction on an input image.
:param input_tensor: A batch of images to predict.
......@@ -188,7 +195,7 @@ class DAN:
if torch.all(reached_end):
break
# Concatenate tensors for each token
confidence_scores = (
torch.cat(confidence_scores, dim=1).cpu().detach().numpy()
......@@ -270,13 +277,19 @@ def run(
if "word" in confidence_score_levels:
word_probs = compute_prob_by_separator(text, char_confidences, ["\n", " "])
result["confidences"].update({"word": [np.around(c, 2) for c in word_probs]})
result["confidences"].update(
{"word": [np.around(c, 2) for c in word_probs]}
)
if "line" in confidence_score_levels:
line_probs = compute_prob_by_separator(text, char_confidences, ["\n"])
result["confidences"].update({"line": [np.around(c, 2) for c in line_probs]})
result["confidences"].update(
{"line": [np.around(c, 2) for c in line_probs]}
)
if "char" in confidence_score_levels:
result["confidences"].update({"char": [np.around(c, 2) for c in char_confidences]})
result["confidences"].update(
{"char": [np.around(c, 2) for c in char_confidences]}
)
# Save gif with attention map
if attention_map:
gif_filename = f"{output}/{image.stem}_{attention_map_level}.gif"
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment