Skip to content
Snippets Groups Projects
Commit ffe8711e authored by Mélodie Boillet's avatar Mélodie Boillet Committed by Yoann Schneider
Browse files

Add batch prediction code

parent e90f1bae
No related branches found
No related tags found
1 merge request!64Add batch prediction code
......@@ -66,13 +66,11 @@ class DAN:
self.mean, self.std = parameters["mean"], parameters["std"]
self.max_chars = parameters["max_char_prediction"]
def predict(self, input_image, confidences=False):
def preprocess(self, input_image):
"""
Run prediction on an input image.
:param input_image: The image to predict.
:param confidences: Return the characters probabilities.
Preprocess an input_image.
:param input_image: The input image to preprocess.
"""
# Preprocess image.
assert isinstance(
input_image, np.ndarray
), "Input image must be an np.array in RGB"
......@@ -80,12 +78,17 @@ class DAN:
if len(input_image.shape) < 3:
input_image = cv2.cvtColor(input_image, cv2.COLOR_GRAY2RGB)
reduced_size = [input_image.shape[:2]]
input_image = (input_image - self.mean) / self.std
input_image = np.expand_dims(input_image.transpose((2, 0, 1)), axis=0)
input_tensor = torch.from_numpy(input_image).to(self.device)
logging.debug("Image pre-processed")
return input_image
def predict(self, input_tensor, input_sizes, confidences=False):
"""
Run prediction on an input image.
:param input_tensor: A batch of images to predict.
:param input_sizes: The original images sizes.
:param confidences: Return the characters probabilities.
"""
input_tensor.to(self.device)
start_token = len(self.charset) + 1
end_token = len(self.charset)
......@@ -125,7 +128,7 @@ class DAN:
features,
enhanced_features,
predicted_tokens,
reduced_size,
input_sizes,
predicted_tokens_len,
features_size,
start=0,
......@@ -169,8 +172,8 @@ class DAN:
predicted_text = [
LM_ind_to_str(self.charset, t, oov_symbol="") for t in predicted_tokens
]
logging.info("Image processed")
logging.info("Images processed")
if confidences:
return predicted_text[0], confidence_scores[0]
return predicted_text[0]
return predicted_text, confidence_scores
return predicted_text
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment