From 60f2f9ed38ea091786a77d110d8910e3bf99aa5e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sol=C3=A8ne=20Tarride?= <starride@teklia.com>
Date: Wed, 20 Sep 2023 14:48:36 +0200
Subject: [PATCH] Support batch_size>1

---
 dan/ocr/decoder.py           | 2 +-
 dan/ocr/predict/inference.py | 4 ----
 2 files changed, 1 insertion(+), 5 deletions(-)

diff --git a/dan/ocr/decoder.py b/dan/ocr/decoder.py
index 568a7cc1..1b8b89d2 100644
--- a/dan/ocr/decoder.py
+++ b/dan/ocr/decoder.py
@@ -527,7 +527,7 @@ class CTCLanguageDecoder:
         ctc_probs = ctc_probs.log()
 
         # Insert CTC tensor between frames
-        for fn in range(batch_frames[0] - 1):
+        for fn in range(batch_frames.max() - 1):
             batch_features = torch.cat(
                 [
                     batch_features[:, : 2 * fn + 1, :],
diff --git a/dan/ocr/predict/inference.py b/dan/ocr/predict/inference.py
index b04995d0..218117b0 100644
--- a/dan/ocr/predict/inference.py
+++ b/dan/ocr/predict/inference.py
@@ -473,14 +473,10 @@ def run(
     cuda_device = f":{gpu_device}" if gpu_device is not None else ""
     device = f"cuda{cuda_device}" if torch.cuda.is_available() else "cpu"
     dan_model = DAN(device, temperature)
-<<<<<<< HEAD
-    dan_model.load(model, parameters, charset, mode="eval")
-=======
     dan_model.load(
         model, parameters, charset, mode="eval", use_language_model=use_language_model
     )
     batch_size = 1 if use_language_model else batch_size
->>>>>>> e7c611f (Fix shape of tot_prob)
 
     images = image_dir.rglob(f"*{image_extension}") if not image else [image]
     for image_batch in list_to_batches(images, n=batch_size):
--
GitLab
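
A minimal sketch of the interleaving that the decoder.py hunk fixes, assuming the surrounding CTCLanguageDecoder code inserts a CTC-blank frame between consecutive feature frames: with batch_size > 1 the batch is padded to its longest sequence, so the loop bound must be batch_frames.max() rather than batch_frames[0], which only reflects the first sample. Tensor names other than batch_frames and batch_features (ctc_blank, n_tokens) are invented for the illustration and are not taken from the repository.

# Illustrative sketch only, not the DAN implementation.
import torch

batch_size, n_tokens = 2, 5
batch_frames = torch.tensor([3, 6])  # per-sample frame counts, first sample shorter
max_frames = int(batch_frames.max())

# Features padded to the longest sample, plus one blank-score frame to insert.
batch_features = torch.randn(batch_size, max_frames, n_tokens)
ctc_blank = torch.full((batch_size, 1, n_tokens), -10.0)

# Insert a blank frame after every original frame except the last one.
# Looping over batch_frames[0] - 1 == 2 positions would stop interleaving at the
# first sample's length; batch_frames.max() - 1 == 5 covers the whole batch.
for fn in range(max_frames - 1):
    batch_features = torch.cat(
        [
            batch_features[:, : 2 * fn + 1, :],
            ctc_blank,
            batch_features[:, 2 * fn + 1 :, :],
        ],
        dim=1,
    )

assert batch_features.shape[1] == 2 * max_frames - 1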
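
The inference.py hunk keeps the existing batching loop but pins batch_size to 1 when the language model is enabled. As a rough reference for how that loop consumes images, here is a stand-in for a list_to_batches(iterable, n=...) helper; it only mirrors the call shown in the diff and is not the repository's implementation.

# Assumed behaviour of the batching helper used by run(); illustrative only.
from itertools import islice
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def list_to_batches(iterable: Iterable[T], n: int) -> Iterator[List[T]]:
    """Yield successive lists of at most n items from the iterable."""
    it = iter(iterable)
    while batch := list(islice(it, n)):
        yield batch


# With use_language_model=True the patch forces n=1, so every batch holds a
# single image; otherwise batches of up to batch_size images are predicted.
images = ["page_1.jpg", "page_2.jpg", "page_3.jpg"]
for image_batch in list_to_batches(images, n=1):
    print(image_batch)  # ['page_1.jpg'], then ['page_2.jpg'], then ['page_3.jpg']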