diff --git a/dan/predict/prediction.py b/dan/predict/prediction.py index 3b9fa98c11a79877d068e1ea9ced4c3b70650182..3726679eaebb9d7cdde458413f10df230f22e529 100644 --- a/dan/predict/prediction.py +++ b/dan/predict/prediction.py @@ -282,9 +282,9 @@ def process_image( logger.debug("Image pre-processed.") # Convert to tensor of size (batch_size, channel, height, width) with batch_size=1 - input_tensor = torch.tensor(im_p).permute(2, 0, 1).unsqueeze(0) + input_tensor = im_p.unsqueeze(0) input_tensor = input_tensor.to(device) - input_sizes = [im.shape[:2]] + input_sizes = [im_p.shape[1:]] # Parse delimiters to regex word_separators = parse_delimiters(word_separators) @@ -421,7 +421,7 @@ def run( cuda_device = f":{gpu_device}" if gpu_device is not None else "" device = f"cuda{cuda_device}" if torch.cuda.is_available() else "cpu" dan_model = DAN(device, temperature) - dan_model.load(model, parameters, charset, mode="train") + dan_model.load(model, parameters, charset, mode="eval") if image: process_image( image,