diff --git a/dan/ocr/decoder.py b/dan/ocr/decoder.py index 568a7cc1a867e2a2ba56ff398e135646d848a5af..1b8b89d22c8bb2af66b67888c981778edf9a8e63 100644 --- a/dan/ocr/decoder.py +++ b/dan/ocr/decoder.py @@ -527,7 +527,7 @@ class CTCLanguageDecoder: ctc_probs = ctc_probs.log() # Insert CTC tensor between frames - for fn in range(batch_frames[0] - 1): + for fn in range(batch_frames.max() - 1): batch_features = torch.cat( [ batch_features[:, : 2 * fn + 1, :], diff --git a/dan/ocr/predict/inference.py b/dan/ocr/predict/inference.py index b04995d0d3569e93d971f894206a41a66e68f5ad..218117b098da5e62d27d952808678ccaa68a2beb 100644 --- a/dan/ocr/predict/inference.py +++ b/dan/ocr/predict/inference.py @@ -473,14 +473,10 @@ def run( cuda_device = f":{gpu_device}" if gpu_device is not None else "" device = f"cuda{cuda_device}" if torch.cuda.is_available() else "cpu" dan_model = DAN(device, temperature) -<<<<<<< HEAD - dan_model.load(model, parameters, charset, mode="eval") -======= dan_model.load( model, parameters, charset, mode="eval", use_language_model=use_language_model ) batch_size = 1 if use_language_model else batch_size ->>>>>>> e7c611f (Fix shape of tot_prob) images = image_dir.rglob(f"*{image_extension}") if not image else [image] for image_batch in list_to_batches(images, n=batch_size):