Skip to content
Snippets Groups Projects
Commit 517510aa authored by Solene Tarride's avatar Solene Tarride Committed by Solene Tarride
Browse files

Fix tests

parent f8630f6e
No related branches found
No related tags found
No related merge requests found
......@@ -169,7 +169,7 @@ def add_predict_parser(subcommands) -> None:
)
parser.add_argument(
"--use-language-model",
help="Whether to use an explicit language model to rescore text hypothesis.",
help="Whether to use an explicit language model to rescore text hypotheses.",
action="store_true",
required=False,
)
......
......@@ -77,16 +77,6 @@ class DAN:
decoder = GlobalHTADecoder(parameters["decoder"]).to(self.device)
decoder.load_state_dict(checkpoint["decoder_state_dict"], strict=True)
self.lm_decoder = CTCLanguageDecoder(
language_model_path=parameters["lm_decoder"]["language_model_path"],
lexicon_path=parameters["lm_decoder"]["lexicon_path"],
tokens_path=parameters["lm_decoder"]["tokens_path"],
language_model_weight=parameters["lm_decoder"]["language_model_weight"],
blank_token=parameters["lm_decoder"]["blank_token"],
unk_token=parameters["lm_decoder"]["unk_token"],
sil_token=parameters["lm_decoder"]["sil_token"],
)
logger.debug(f"Loaded model {model_path}")
if mode == "train":
......@@ -100,17 +90,16 @@ class DAN:
self.encoder = encoder
self.decoder = decoder
self.lm_decoder = None
if use_language_model and parameters["language_model"]["weight"] > 0:
logger.info(
f"Decoding with a language model (weight={parameters['language_model']['weight']})."
)
if use_language_model:
self.lm_decoder = CTCLanguageDecoder(
language_model_path=parameters["language_model"]["model"],
lexicon_path=parameters["language_model"]["lexicon"],
tokens_path=parameters["language_model"]["tokens"],
language_model_weight=parameters["language_model"]["weight"],
language_model_path=parameters["lm_decoder"]["language_model_path"],
lexicon_path=parameters["lm_decoder"]["lexicon_path"],
tokens_path=parameters["lm_decoder"]["tokens_path"],
language_model_weight=parameters["lm_decoder"]["language_model_weight"],
blank_token=parameters["lm_decoder"]["blank_token"],
unk_token=parameters["lm_decoder"]["unk_token"],
sil_token=parameters["lm_decoder"]["sil_token"],
)
self.mean, self.std = (
......@@ -370,7 +359,6 @@ def process_batch(
# Return LM results
if use_language_model:
result["language_model"] = {}
print(prediction)
result["language_model"]["text"] = prediction["language_model"]["text"][idx]
result["language_model"]["confidence"] = prediction["language_model"][
"confidence"
......@@ -477,7 +465,7 @@ def run(
:param batch_size: Size of the batches for prediction.
:param tokens: NER tokens used.
:param start_token: Use a specific starting token at the beginning of the prediction. Useful when making predictions on different single pages.
:param use_language_model: Whether to use an explicit language model to rescore text hypothesis.
:param use_language_model: Whether to use an explicit language model to rescore text hypotheses.
"""
# Create output directory if necessary
if not output.exists():
......@@ -487,12 +475,7 @@ def run(
cuda_device = f":{gpu_device}" if gpu_device is not None else ""
device = f"cuda{cuda_device}" if torch.cuda.is_available() else "cpu"
dan_model = DAN(device, temperature)
dan_model.load(
model, parameters, charset, mode="eval", use_language_model=use_language_model
)
# Do not use LM with invalid LM weight
use_language_model = dan_model.lm_decoder is not None
dan_model.load(model, parameters, charset, mode="eval")
images = image_dir.rglob(f"*{image_extension}") if not image else [image]
for image_batch in list_to_batches(images, n=batch_size):
......
Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment