diff --git a/dan/predict/prediction.py b/dan/predict/prediction.py index fa4bcea0ffa3cf434125fed4f913cbad171e3d4d..64552e9624fc27d24788c0619cda8b86dff1d9c6 100644 --- a/dan/predict/prediction.py +++ b/dan/predict/prediction.py @@ -299,13 +299,13 @@ def run( dan_model = DAN(device, temperature) dan_model.load(model, parameters, charset, mode="eval") - im = read_image(image, scale=scale) - # Load image and pre-process it if image_max_width: - h, w, c = read_image(image, scale=1).shape + _, w, _ = read_image(image, scale=1).shape ratio = image_max_width / w im = read_image(image, ratio) + else: + im = read_image(image, scale=scale) logger.info("Image loaded.") im_p = dan_model.preprocess(im)