Skip to content
Snippets Groups Projects
Commit 8ced0010 authored by Solene Tarride's avatar Solene Tarride
Browse files

fix lint

parent 828a5db1
No related branches found
No related tags found
1 merge request!66Compute confidence scores by char, word or line
This commit is part of merge request !66. Comments created here will be created in the context of that merge request.
...@@ -109,7 +109,14 @@ class DAN: ...@@ -109,7 +109,14 @@ class DAN:
input_image = (input_image - self.mean) / self.std input_image = (input_image - self.mean) / self.std
return input_image return input_image
def predict(self, input_tensor, input_sizes, confidences=False, attentions=False, confidences_sep=None): def predict(
self,
input_tensor,
input_sizes,
confidences=False,
attentions=False,
confidences_sep=None,
):
""" """
Run prediction on an input image. Run prediction on an input image.
:param input_tensor: A batch of images to predict. :param input_tensor: A batch of images to predict.
...@@ -188,7 +195,7 @@ class DAN: ...@@ -188,7 +195,7 @@ class DAN:
if torch.all(reached_end): if torch.all(reached_end):
break break
# Concatenate tensors for each token # Concatenate tensors for each token
confidence_scores = ( confidence_scores = (
torch.cat(confidence_scores, dim=1).cpu().detach().numpy() torch.cat(confidence_scores, dim=1).cpu().detach().numpy()
...@@ -270,13 +277,19 @@ def run( ...@@ -270,13 +277,19 @@ def run(
if "word" in confidence_score_levels: if "word" in confidence_score_levels:
word_probs = compute_prob_by_separator(text, char_confidences, ["\n", " "]) word_probs = compute_prob_by_separator(text, char_confidences, ["\n", " "])
result["confidences"].update({"word": [np.around(c, 2) for c in word_probs]}) result["confidences"].update(
{"word": [np.around(c, 2) for c in word_probs]}
)
if "line" in confidence_score_levels: if "line" in confidence_score_levels:
line_probs = compute_prob_by_separator(text, char_confidences, ["\n"]) line_probs = compute_prob_by_separator(text, char_confidences, ["\n"])
result["confidences"].update({"line": [np.around(c, 2) for c in line_probs]}) result["confidences"].update(
{"line": [np.around(c, 2) for c in line_probs]}
)
if "char" in confidence_score_levels: if "char" in confidence_score_levels:
result["confidences"].update({"char": [np.around(c, 2) for c in char_confidences]}) result["confidences"].update(
{"char": [np.around(c, 2) for c in char_confidences]}
)
# Save gif with attention map # Save gif with attention map
if attention_map: if attention_map:
gif_filename = f"{output}/{image.stem}_{attention_map_level}.gif" gif_filename = f"{output}/{image.stem}_{attention_map_level}.gif"
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment