From e3262ae6b11bc9829b97ff5e034085a71bfab3e5 Mon Sep 17 00:00:00 2001
From: Yoann Schneider <yschneider@teklia.com>
Date: Fri, 17 Nov 2023 11:48:47 +0100
Subject: [PATCH] Remove use_language_model and use property instead

---
 dan/ocr/predict/inference.py | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/dan/ocr/predict/inference.py b/dan/ocr/predict/inference.py
index c869d4e1..85f8e081 100644
--- a/dan/ocr/predict/inference.py
+++ b/dan/ocr/predict/inference.py
@@ -117,6 +117,13 @@ class DAN:
         )
         self.max_chars = parameters["max_char_prediction"]
 
+    @property
+    def use_lm(self) -> bool:
+        """
+        Whether a Language Model decoder is available, i.e. `self.lm_decoder` is not None
+        """
+        return self.lm_decoder is not None
+
     def preprocess(self, path: str) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Preprocess an image.
@@ -154,7 +161,6 @@ class DAN:
         tokens: Dict[str, EntityType] = {},
         start_token: str = None,
         max_object_height: int = 50,
-        use_language_model: bool = False,
     ) -> dict:
         """
         Run prediction on an input image.
@@ -280,7 +286,7 @@ class DAN:
         out = {}
 
         out["text"] = predicted_text
-        if use_language_model:
+        if self.use_lm:
             out["language_model"] = self.lm_decoder(tot_pred, prediction_len)
         if confidences:
             out["confidences"] = confidence_scores
@@ -321,7 +327,6 @@ def process_batch(
     max_object_height: int,
     tokens: Dict[str, EntityType],
     start_token: str,
-    use_language_model: bool,
 ) -> None:
     input_images, visu_images, input_sizes = [], [], []
     logger.info("Loading images...")
@@ -355,7 +360,6 @@ def process_batch(
         tokens=tokens,
         max_object_height=max_object_height,
         start_token=start_token,
-        use_language_model=use_language_model,
     )
 
     logger.info("Prediction parsing...")
@@ -365,7 +369,7 @@ def process_batch(
 
         if predicted_text:
             # Return LM results
-            if use_language_model:
+            if dan_model.use_lm:
                 result["language_model"] = {
                     "text": prediction["language_model"]["text"][idx],
                     "confidence": prediction["language_model"]["confidence"][idx],
@@ -471,9 +475,6 @@ def run(
     dan_model = DAN(device, temperature)
     dan_model.load(model, mode="eval", use_language_model=use_language_model)
 
-    # Do not use LM with invalid LM weight
-    use_language_model = dan_model.lm_decoder is not None
-
     images = image_dir.rglob(f"*{image_extension}")
     for image_batch in list_to_batches(images, n=batch_size):
         process_batch(
@@ -492,5 +493,4 @@ def run(
             max_object_height,
             tokens,
             start_token,
-            use_language_model,
         )
-- 
GitLab