From ffe8711e8c71a984a014c70587d781228e7c1dd5 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?M=C3=A9lodie=20Boillet?= <boillet@teklia.com>
Date: Wed, 22 Feb 2023 12:26:00 +0000
Subject: [PATCH] Add batch prediction code

---
 dan/predict.py | 31 +++++++++++++++++--------------
 1 file changed, 17 insertions(+), 14 deletions(-)

diff --git a/dan/predict.py b/dan/predict.py
index 753e584d..2f855f72 100644
--- a/dan/predict.py
+++ b/dan/predict.py
@@ -66,13 +66,11 @@ class DAN:
         self.mean, self.std = parameters["mean"], parameters["std"]
         self.max_chars = parameters["max_char_prediction"]
 
-    def predict(self, input_image, confidences=False):
+    def preprocess(self, input_image):
         """
-        Run prediction on an input image.
-        :param input_image: The image to predict.
-        :param confidences: Return the characters probabilities.
+        Preprocess an input_image.
+        :param input_image: The input image to preprocess.
         """
-        # Preprocess image.
         assert isinstance(
             input_image, np.ndarray
         ), "Input image must be an np.array in RGB"
@@ -80,12 +78,17 @@ class DAN:
 
         if len(input_image.shape) < 3:
             input_image = cv2.cvtColor(input_image, cv2.COLOR_GRAY2RGB)
 
-        reduced_size = [input_image.shape[:2]]
-        input_image = (input_image - self.mean) / self.std
-        input_image = np.expand_dims(input_image.transpose((2, 0, 1)), axis=0)
-        input_tensor = torch.from_numpy(input_image).to(self.device)
-        logging.debug("Image pre-processed")
+        return input_image
+
+    def predict(self, input_tensor, input_sizes, confidences=False):
+        """
+        Run prediction on an input image.
+        :param input_tensor: A batch of images to predict.
+        :param input_sizes: The original images sizes.
+        :param confidences: Return the characters probabilities.
+        """
+        input_tensor.to(self.device)
 
         start_token = len(self.charset) + 1
         end_token = len(self.charset)
@@ -125,7 +128,7 @@ class DAN:
             features,
             enhanced_features,
             predicted_tokens,
-            reduced_size,
+            input_sizes,
             predicted_tokens_len,
             features_size,
             start=0,
@@ -169,8 +172,8 @@ class DAN:
         predicted_text = [
             LM_ind_to_str(self.charset, t, oov_symbol="") for t in predicted_tokens
         ]
-        logging.info("Image processed")
+        logging.info("Images processed")
 
         if confidences:
-            return predicted_text[0], confidence_scores[0]
-        return predicted_text[0]
+            return predicted_text, confidence_scores
+        return predicted_text
--
GitLab
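
For context, the patch splits the old single-image `predict` into a per-image `preprocess` step and a batched `predict` that takes an already-built tensor plus the original image sizes. The snippet below is a rough sketch of how a caller might drive that two-step API; the `pad_and_stack` helper, the caller-side normalisation and transposition, and the way the `DAN` model instance is obtained are assumptions for illustration, not part of this patch.

```python
import numpy as np
import torch


def pad_and_stack(images):
    """Hypothetical helper: zero-pad HWC float images to a common size and stack to NCHW."""
    max_h = max(im.shape[0] for im in images)
    max_w = max(im.shape[1] for im in images)
    batch = np.zeros((len(images), 3, max_h, max_w), dtype=np.float32)
    for i, im in enumerate(images):
        h, w = im.shape[:2]
        # Transpose HWC -> CHW and copy into the padded batch.
        batch[i, :, :h, :w] = im.transpose((2, 0, 1))
    return torch.from_numpy(batch)


def predict_batch(model, images, confidences=False):
    """Sketch of batched inference with the new preprocess/predict split.

    `model` is assumed to be a loaded `DAN` instance (after `load(...)`),
    so `model.mean`, `model.std` and `model.device` are available.
    """
    # Per-image preprocessing (RGB conversion and sanity checks) from the patch.
    processed = [model.preprocess(im) for im in images]
    # Keep the original sizes: predict() uses them in place of the old reduced_size.
    input_sizes = [im.shape[:2] for im in processed]
    # Normalisation and stacking are assumed to happen caller-side in this sketch.
    normalized = [(im - model.mean) / model.std for im in processed]
    input_tensor = pad_and_stack(normalized).to(model.device)
    return model.predict(input_tensor, input_sizes, confidences=confidences)
```

Since `Tensor.to` is not in-place, the sketch moves the batch onto the model's device before calling `predict`. With `confidences=True`, the batched `predict` now returns one predicted text and one confidence score per image rather than a single value.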