From e3262ae6b11bc9829b97ff5e034085a71bfab3e5 Mon Sep 17 00:00:00 2001 From: Yoann Schneider <yschneider@teklia.com> Date: Fri, 17 Nov 2023 11:48:47 +0100 Subject: [PATCH] Remove use_language_model and use property instead --- dan/ocr/predict/inference.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/dan/ocr/predict/inference.py b/dan/ocr/predict/inference.py index c869d4e1..85f8e081 100644 --- a/dan/ocr/predict/inference.py +++ b/dan/ocr/predict/inference.py @@ -117,6 +117,13 @@ class DAN: ) self.max_chars = parameters["max_char_prediction"] + @property + def use_lm(self) -> bool: + """ + Whether the model decodes with a Language Model + """ + return self.lm_decoder is not None + def preprocess(self, path: str) -> Tuple[torch.Tensor, torch.Tensor]: """ Preprocess an image. @@ -154,7 +161,6 @@ class DAN: tokens: Dict[str, EntityType] = {}, start_token: str = None, max_object_height: int = 50, - use_language_model: bool = False, ) -> dict: """ Run prediction on an input image. @@ -280,7 +286,7 @@ class DAN: out = {} out["text"] = predicted_text - if use_language_model: + if self.use_lm: out["language_model"] = self.lm_decoder(tot_pred, prediction_len) if confidences: out["confidences"] = confidence_scores @@ -321,7 +327,6 @@ def process_batch( max_object_height: int, tokens: Dict[str, EntityType], start_token: str, - use_language_model: bool, ) -> None: input_images, visu_images, input_sizes = [], [], [] logger.info("Loading images...") @@ -355,7 +360,6 @@ def process_batch( tokens=tokens, max_object_height=max_object_height, start_token=start_token, - use_language_model=use_language_model, ) logger.info("Prediction parsing...") @@ -365,7 +369,7 @@ def process_batch( if predicted_text: # Return LM results - if use_language_model: + if dan_model.use_lm: result["language_model"] = { "text": prediction["language_model"]["text"][idx], "confidence": prediction["language_model"]["confidence"][idx], @@ -471,9 +475,6 @@ def run( dan_model = DAN(device, temperature) dan_model.load(model, mode="eval", use_language_model=use_language_model) - # Do not use LM with invalid LM weight - use_language_model = dan_model.lm_decoder is not None - images = image_dir.rglob(f"*{image_extension}") for image_batch in list_to_batches(images, n=batch_size): process_batch( @@ -492,5 +493,4 @@ def run( max_object_height, tokens, start_token, - use_language_model, ) -- GitLab