Skip to content
Snippets Groups Projects
Commit 78a69471 authored by Mélodie Boillet's avatar Mélodie Boillet Committed by Solene Tarride
Browse files

Fix input tensor used during prediction

parent f73ffbc1
No related branches found
No related tags found
1 merge request!194Fix input tensor used during prediction
......@@ -283,9 +283,9 @@ def process_image(
logger.debug("Image pre-processed.")
# Convert to tensor of size (batch_size, channel, height, width) with batch_size=1
input_tensor = torch.tensor(im_p).permute(2, 0, 1).unsqueeze(0)
input_tensor = im_p.unsqueeze(0)
input_tensor = input_tensor.to(device)
input_sizes = [im.shape[:2]]
input_sizes = [im_p.shape[1:]]
# Parse delimiters to regex
word_separators = parse_delimiters(word_separators)
......@@ -422,7 +422,7 @@ def run(
cuda_device = f":{gpu_device}" if gpu_device is not None else ""
device = f"cuda{cuda_device}" if torch.cuda.is_available() else "cpu"
dan_model = DAN(device, temperature)
dan_model.load(model, parameters, charset, mode="train")
dan_model.load(model, parameters, charset, mode="eval")
if image:
process_image(
image,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment