Skip to content
Snippets Groups Projects
Commit 60f2f9ed authored by Solene Tarride's avatar Solene Tarride
Browse files

Support batch_size>1

parent 0815f2bb
No related branches found
No related tags found
1 merge request!287Support subword and word language models
This commit is part of merge request !287. Comments created here will be created in the context of that merge request.
......@@ -527,7 +527,7 @@ class CTCLanguageDecoder:
ctc_probs = ctc_probs.log()
# Insert CTC tensor between frames
for fn in range(batch_frames[0] - 1):
for fn in range(batch_frames.max() - 1):
batch_features = torch.cat(
[
batch_features[:, : 2 * fn + 1, :],
......
......@@ -473,14 +473,10 @@ def run(
cuda_device = f":{gpu_device}" if gpu_device is not None else ""
device = f"cuda{cuda_device}" if torch.cuda.is_available() else "cpu"
dan_model = DAN(device, temperature)
<<<<<<< HEAD
dan_model.load(model, parameters, charset, mode="eval")
=======
dan_model.load(
model, parameters, charset, mode="eval", use_language_model=use_language_model
)
batch_size = 1 if use_language_model else batch_size
>>>>>>> e7c611f (Fix shape of tot_prob)
images = image_dir.rglob(f"*{image_extension}") if not image else [image]
for image_batch in list_to_batches(images, n=batch_size):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment