Skip to content
Snippets Groups Projects
Verified Commit e3262ae6 authored by Yoann Schneider's avatar Yoann Schneider :tennis:
Browse files

Remove use_language_model and use property instead

parent afbe821f
No related branches found
No related tags found
1 merge request!321Remove use_language_model and use property instead
...@@ -117,6 +117,13 @@ class DAN: ...@@ -117,6 +117,13 @@ class DAN:
) )
self.max_chars = parameters["max_char_prediction"] self.max_chars = parameters["max_char_prediction"]
@property
def use_lm(self) -> bool:
    """
    True when a language-model decoder is attached, i.e. decoding
    will go through the Language Model; False otherwise.
    """
    decoder = self.lm_decoder
    return decoder is not None
def preprocess(self, path: str) -> Tuple[torch.Tensor, torch.Tensor]: def preprocess(self, path: str) -> Tuple[torch.Tensor, torch.Tensor]:
""" """
Preprocess an image. Preprocess an image.
...@@ -154,7 +161,6 @@ class DAN: ...@@ -154,7 +161,6 @@ class DAN:
tokens: Dict[str, EntityType] = {}, tokens: Dict[str, EntityType] = {},
start_token: str = None, start_token: str = None,
max_object_height: int = 50, max_object_height: int = 50,
use_language_model: bool = False,
) -> dict: ) -> dict:
""" """
Run prediction on an input image. Run prediction on an input image.
...@@ -280,7 +286,7 @@ class DAN: ...@@ -280,7 +286,7 @@ class DAN:
out = {} out = {}
out["text"] = predicted_text out["text"] = predicted_text
if use_language_model: if self.use_lm:
out["language_model"] = self.lm_decoder(tot_pred, prediction_len) out["language_model"] = self.lm_decoder(tot_pred, prediction_len)
if confidences: if confidences:
out["confidences"] = confidence_scores out["confidences"] = confidence_scores
...@@ -321,7 +327,6 @@ def process_batch( ...@@ -321,7 +327,6 @@ def process_batch(
max_object_height: int, max_object_height: int,
tokens: Dict[str, EntityType], tokens: Dict[str, EntityType],
start_token: str, start_token: str,
use_language_model: bool,
) -> None: ) -> None:
input_images, visu_images, input_sizes = [], [], [] input_images, visu_images, input_sizes = [], [], []
logger.info("Loading images...") logger.info("Loading images...")
...@@ -355,7 +360,6 @@ def process_batch( ...@@ -355,7 +360,6 @@ def process_batch(
tokens=tokens, tokens=tokens,
max_object_height=max_object_height, max_object_height=max_object_height,
start_token=start_token, start_token=start_token,
use_language_model=use_language_model,
) )
logger.info("Prediction parsing...") logger.info("Prediction parsing...")
...@@ -365,7 +369,7 @@ def process_batch( ...@@ -365,7 +369,7 @@ def process_batch(
if predicted_text: if predicted_text:
# Return LM results # Return LM results
if use_language_model: if dan_model.use_lm:
result["language_model"] = { result["language_model"] = {
"text": prediction["language_model"]["text"][idx], "text": prediction["language_model"]["text"][idx],
"confidence": prediction["language_model"]["confidence"][idx], "confidence": prediction["language_model"]["confidence"][idx],
...@@ -471,9 +475,6 @@ def run( ...@@ -471,9 +475,6 @@ def run(
dan_model = DAN(device, temperature) dan_model = DAN(device, temperature)
dan_model.load(model, mode="eval", use_language_model=use_language_model) dan_model.load(model, mode="eval", use_language_model=use_language_model)
# Do not use LM with invalid LM weight
use_language_model = dan_model.lm_decoder is not None
images = image_dir.rglob(f"*{image_extension}") images = image_dir.rglob(f"*{image_extension}")
for image_batch in list_to_batches(images, n=batch_size): for image_batch in list_to_batches(images, n=batch_size):
process_batch( process_batch(
...@@ -492,5 +493,4 @@ def run( ...@@ -492,5 +493,4 @@ def run(
max_object_height, max_object_height,
tokens, tokens,
start_token, start_token,
use_language_model,
) )
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment