From 8ced0010d30bbffa2744da565bda5e1ef9ea04db Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sol=C3=A8ne=20Tarride?= <starride@teklia.com>
Date: Wed, 22 Feb 2023 18:57:17 +0100
Subject: [PATCH] fix lint

---
 dan/predict/prediction.py | 25 +++++++++++++++++++------
 1 file changed, 19 insertions(+), 6 deletions(-)

diff --git a/dan/predict/prediction.py b/dan/predict/prediction.py
index c93297d1..c6c000a7 100644
--- a/dan/predict/prediction.py
+++ b/dan/predict/prediction.py
@@ -109,7 +109,14 @@ class DAN:
         input_image = (input_image - self.mean) / self.std
         return input_image
 
-    def predict(self, input_tensor, input_sizes, confidences=False, attentions=False, confidences_sep=None):
+    def predict(
+        self,
+        input_tensor,
+        input_sizes,
+        confidences=False,
+        attentions=False,
+        confidences_sep=None,
+    ):
         """
         Run prediction on an input image.
         :param input_tensor: A batch of images to predict.
@@ -188,7 +195,7 @@ class DAN:
 
                 if torch.all(reached_end):
                     break
-            
+
             # Concatenate tensors for each token
             confidence_scores = (
                 torch.cat(confidence_scores, dim=1).cpu().detach().numpy()
@@ -270,13 +277,19 @@ def run(
 
         if "word" in confidence_score_levels:
             word_probs = compute_prob_by_separator(text, char_confidences, ["\n", " "])
-            result["confidences"].update({"word": [np.around(c, 2) for c in word_probs]})
+            result["confidences"].update(
+                {"word": [np.around(c, 2) for c in word_probs]}
+            )
         if "line" in confidence_score_levels:
             line_probs = compute_prob_by_separator(text, char_confidences, ["\n"])
-            result["confidences"].update({"line": [np.around(c, 2) for c in line_probs]})
+            result["confidences"].update(
+                {"line": [np.around(c, 2) for c in line_probs]}
+            )
         if "char" in confidence_score_levels:
-            result["confidences"].update({"char": [np.around(c, 2) for c in char_confidences]})
-                        
+            result["confidences"].update(
+                {"char": [np.around(c, 2) for c in char_confidences]}
+            )
+
     # Save gif with attention map
     if attention_map:
         gif_filename = f"{output}/{image.stem}_{attention_map_level}.gif"
-- 
GitLab