diff --git a/dan/predict/prediction.py b/dan/predict/prediction.py index ac029c84d0891d23e24f095ef5ad2830b0567db3..9aa1cd5f1e5c5490ec5cf38632e03af72e60cf87 100644 --- a/dan/predict/prediction.py +++ b/dan/predict/prediction.py @@ -283,9 +283,9 @@ def process_image( logger.debug("Image pre-processed.") # Convert to tensor of size (batch_size, channel, height, width) with batch_size=1 - input_tensor = torch.tensor(im_p).permute(2, 0, 1).unsqueeze(0) + input_tensor = im_p.unsqueeze(0) input_tensor = input_tensor.to(device) - input_sizes = [im.shape[:2]] + input_sizes = [im_p.shape[1:]] # Parse delimiters to regex word_separators = parse_delimiters(word_separators) @@ -422,7 +422,7 @@ def run( cuda_device = f":{gpu_device}" if gpu_device is not None else "" device = f"cuda{cuda_device}" if torch.cuda.is_available() else "cpu" dan_model = DAN(device, temperature) - dan_model.load(model, parameters, charset, mode="train") + dan_model.load(model, parameters, charset, mode="eval") if image: process_image( image,