Skip to content
Snippets Groups Projects
Verified Commit e3262ae6 authored by Yoann Schneider's avatar Yoann Schneider :tennis:
Browse files

Remove use_language_model and use property instead

parent afbe821f
No related branches found
No related tags found
1 merge request!321Remove use_language_model and use property instead
......@@ -117,6 +117,13 @@ class DAN:
)
self.max_chars = parameters["max_char_prediction"]
@property
def use_lm(self) -> bool:
"""
Whether the model decodes with a Language Model
"""
return self.lm_decoder is not None
def preprocess(self, path: str) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Preprocess an image.
......@@ -154,7 +161,6 @@ class DAN:
tokens: Dict[str, EntityType] = {},
start_token: str = None,
max_object_height: int = 50,
use_language_model: bool = False,
) -> dict:
"""
Run prediction on an input image.
......@@ -280,7 +286,7 @@ class DAN:
out = {}
out["text"] = predicted_text
if use_language_model:
if self.use_lm:
out["language_model"] = self.lm_decoder(tot_pred, prediction_len)
if confidences:
out["confidences"] = confidence_scores
......@@ -321,7 +327,6 @@ def process_batch(
max_object_height: int,
tokens: Dict[str, EntityType],
start_token: str,
use_language_model: bool,
) -> None:
input_images, visu_images, input_sizes = [], [], []
logger.info("Loading images...")
......@@ -355,7 +360,6 @@ def process_batch(
tokens=tokens,
max_object_height=max_object_height,
start_token=start_token,
use_language_model=use_language_model,
)
logger.info("Prediction parsing...")
......@@ -365,7 +369,7 @@ def process_batch(
if predicted_text:
# Return LM results
if use_language_model:
if dan_model.use_lm:
result["language_model"] = {
"text": prediction["language_model"]["text"][idx],
"confidence": prediction["language_model"]["confidence"][idx],
......@@ -471,9 +475,6 @@ def run(
dan_model = DAN(device, temperature)
dan_model.load(model, mode="eval", use_language_model=use_language_model)
# Do not use LM with invalid LM weight
use_language_model = dan_model.lm_decoder is not None
images = image_dir.rglob(f"*{image_extension}")
for image_batch in list_to_batches(images, n=batch_size):
process_batch(
......@@ -492,5 +493,4 @@ def run(
max_object_height,
tokens,
start_token,
use_language_model,
)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment