Skip to content
Snippets Groups Projects
Commit 60f2f9ed authored by Solene Tarride's avatar Solene Tarride
Browse files

Support batch_size>1

parent 0815f2bb
No related branches found
No related tags found
1 merge request!287Support subword and word language models
...@@ -527,7 +527,7 @@ class CTCLanguageDecoder: ...@@ -527,7 +527,7 @@ class CTCLanguageDecoder:
ctc_probs = ctc_probs.log() ctc_probs = ctc_probs.log()
# Insert CTC tensor between frames # Insert CTC tensor between frames
for fn in range(batch_frames[0] - 1): for fn in range(batch_frames.max() - 1):
batch_features = torch.cat( batch_features = torch.cat(
[ [
batch_features[:, : 2 * fn + 1, :], batch_features[:, : 2 * fn + 1, :],
......
...@@ -473,14 +473,10 @@ def run( ...@@ -473,14 +473,10 @@ def run(
cuda_device = f":{gpu_device}" if gpu_device is not None else "" cuda_device = f":{gpu_device}" if gpu_device is not None else ""
device = f"cuda{cuda_device}" if torch.cuda.is_available() else "cpu" device = f"cuda{cuda_device}" if torch.cuda.is_available() else "cpu"
dan_model = DAN(device, temperature) dan_model = DAN(device, temperature)
<<<<<<< HEAD
dan_model.load(model, parameters, charset, mode="eval")
=======
dan_model.load( dan_model.load(
model, parameters, charset, mode="eval", use_language_model=use_language_model model, parameters, charset, mode="eval", use_language_model=use_language_model
) )
batch_size = 1 if use_language_model else batch_size batch_size = 1 if use_language_model else batch_size
>>>>>>> e7c611f (Fix shape of tot_prob)
images = image_dir.rglob(f"*{image_extension}") if not image else [image] images = image_dir.rglob(f"*{image_extension}") if not image else [image]
for image_batch in list_to_batches(images, n=batch_size): for image_batch in list_to_batches(images, n=batch_size):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment