From 60f2f9ed38ea091786a77d110d8910e3bf99aa5e Mon Sep 17 00:00:00 2001
From: Solène Tarride <starride@teklia.com>
Date: Wed, 20 Sep 2023 14:48:36 +0200
Subject: [PATCH] Support batch_size>1

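The frame-interleaving loop in CTCLanguageDecoder took its bound from
batch_frames[0], i.e. the frame count of the first sample only. With
batch_size > 1 the samples in a batch can have different frame counts,
so the loop has to run up to the longest one, batch_frames.max(),
otherwise the later frames of longer samples never get the CTC tensor
inserted between them. The hunk in dan/ocr/predict/inference.py only
drops leftover merge-conflict markers, keeping the dan_model.load()
call that passes use_language_model.

A minimal sketch of the bound change, with made-up frame counts (the
real values come from the model's predictions):

    import torch

    # Frame counts of two samples batched together.
    batch_frames = torch.tensor([5, 9])

    # Old bound: derived from the first sample only, so the later
    # frames of the second sample were never interleaved.
    old_bound = int(batch_frames[0]) - 1     # 4 iterations
    # New bound: derived from the longest sample in the batch.
    new_bound = int(batch_frames.max()) - 1  # 8 iterations
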
---
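Note: when use_language_model is set, batch_size is still forced back
to 1 (the unchanged "batch_size = 1 if use_language_model else
batch_size" line below), so the new batching only takes effect for
plain decoding. The images are then grouped by list_to_batches; the
sketch below shows the kind of chunking this is assumed to perform,
purely for illustration (it is not the actual helper from the
repository).

    from itertools import islice
    from typing import Iterable, Iterator, List, TypeVar

    T = TypeVar("T")

    def chunks(items: Iterable[T], n: int) -> Iterator[List[T]]:
        """Yield successive lists of at most n items, the behaviour
        list_to_batches is assumed to have for the image paths."""
        iterator = iter(items)
        while batch := list(islice(iterator, n)):
            yield batch

    # chunks(["a.jpg", "b.jpg", "c.jpg"], n=2) -> ["a.jpg", "b.jpg"], ["c.jpg"]
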
 dan/ocr/decoder.py           | 2 +-
 dan/ocr/predict/inference.py | 4 ----
 2 files changed, 1 insertion(+), 5 deletions(-)

diff --git a/dan/ocr/decoder.py b/dan/ocr/decoder.py
index 568a7cc1..1b8b89d2 100644
--- a/dan/ocr/decoder.py
+++ b/dan/ocr/decoder.py
@@ -527,7 +527,7 @@ class CTCLanguageDecoder:
         ctc_probs = ctc_probs.log()
 
         # Insert CTC tensor between frames
-        for fn in range(batch_frames[0] - 1):
+        for fn in range(batch_frames.max() - 1):
             batch_features = torch.cat(
                 [
                     batch_features[:, : 2 * fn + 1, :],
diff --git a/dan/ocr/predict/inference.py b/dan/ocr/predict/inference.py
index b04995d0..218117b0 100644
--- a/dan/ocr/predict/inference.py
+++ b/dan/ocr/predict/inference.py
@@ -473,14 +473,10 @@ def run(
     cuda_device = f":{gpu_device}" if gpu_device is not None else ""
     device = f"cuda{cuda_device}" if torch.cuda.is_available() else "cpu"
     dan_model = DAN(device, temperature)
-<<<<<<< HEAD
-    dan_model.load(model, parameters, charset, mode="eval")
-=======
     dan_model.load(
         model, parameters, charset, mode="eval", use_language_model=use_language_model
     )
     batch_size = 1 if use_language_model else batch_size
->>>>>>> e7c611f (Fix shape of tot_prob)
 
     images = image_dir.rglob(f"*{image_extension}") if not image else [image]
     for image_batch in list_to_batches(images, n=batch_size):
-- 
GitLab