From 8ced0010d30bbffa2744da565bda5e1ef9ea04db Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sol=C3=A8ne=20Tarride?= <starride@teklia.com> Date: Wed, 22 Feb 2023 18:57:17 +0100 Subject: [PATCH] fix lint --- dan/predict/prediction.py | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/dan/predict/prediction.py b/dan/predict/prediction.py index c93297d1..c6c000a7 100644 --- a/dan/predict/prediction.py +++ b/dan/predict/prediction.py @@ -109,7 +109,14 @@ class DAN: input_image = (input_image - self.mean) / self.std return input_image - def predict(self, input_tensor, input_sizes, confidences=False, attentions=False, confidences_sep=None): + def predict( + self, + input_tensor, + input_sizes, + confidences=False, + attentions=False, + confidences_sep=None, + ): """ Run prediction on an input image. :param input_tensor: A batch of images to predict. @@ -188,7 +195,7 @@ class DAN: if torch.all(reached_end): break - + # Concatenate tensors for each token confidence_scores = ( torch.cat(confidence_scores, dim=1).cpu().detach().numpy() @@ -270,13 +277,19 @@ def run( if "word" in confidence_score_levels: word_probs = compute_prob_by_separator(text, char_confidences, ["\n", " "]) - result["confidences"].update({"word": [np.around(c, 2) for c in word_probs]}) + result["confidences"].update( + {"word": [np.around(c, 2) for c in word_probs]} + ) if "line" in confidence_score_levels: line_probs = compute_prob_by_separator(text, char_confidences, ["\n"]) - result["confidences"].update({"line": [np.around(c, 2) for c in line_probs]}) + result["confidences"].update( + {"line": [np.around(c, 2) for c in line_probs]} + ) if "char" in confidence_score_levels: - result["confidences"].update({"char": [np.around(c, 2) for c in char_confidences]}) - + result["confidences"].update( + {"char": [np.around(c, 2) for c in char_confidences]} + ) + # Save gif with attention map if attention_map: gif_filename = f"{output}/{image.stem}_{attention_map_level}.gif" -- GitLab