Skip to content
Snippets Groups Projects
Commit ffe8711e authored by Mélodie Boillet, committed by Yoann Schneider
Browse files

Add batch prediction code

parent e90f1bae
No related branches found
No related tags found
1 merge request: !64 "Add batch prediction code"
...@@ -66,13 +66,11 @@ class DAN: ...@@ -66,13 +66,11 @@ class DAN:
self.mean, self.std = parameters["mean"], parameters["std"] self.mean, self.std = parameters["mean"], parameters["std"]
self.max_chars = parameters["max_char_prediction"] self.max_chars = parameters["max_char_prediction"]
def predict(self, input_image, confidences=False): def preprocess(self, input_image):
""" """
Run prediction on an input image. Preprocess an input_image.
:param input_image: The image to predict. :param input_image: The input image to preprocess.
:param confidences: Return the characters probabilities.
""" """
# Preprocess image.
assert isinstance( assert isinstance(
input_image, np.ndarray input_image, np.ndarray
), "Input image must be an np.array in RGB" ), "Input image must be an np.array in RGB"
...@@ -80,12 +78,17 @@ class DAN: ...@@ -80,12 +78,17 @@ class DAN:
if len(input_image.shape) < 3: if len(input_image.shape) < 3:
input_image = cv2.cvtColor(input_image, cv2.COLOR_GRAY2RGB) input_image = cv2.cvtColor(input_image, cv2.COLOR_GRAY2RGB)
reduced_size = [input_image.shape[:2]]
input_image = (input_image - self.mean) / self.std input_image = (input_image - self.mean) / self.std
input_image = np.expand_dims(input_image.transpose((2, 0, 1)), axis=0) return input_image
input_tensor = torch.from_numpy(input_image).to(self.device)
logging.debug("Image pre-processed") def predict(self, input_tensor, input_sizes, confidences=False):
"""
Run prediction on an input image.
:param input_tensor: A batch of images to predict.
:param input_sizes: The original images sizes.
:param confidences: Return the characters probabilities.
"""
input_tensor.to(self.device)
start_token = len(self.charset) + 1 start_token = len(self.charset) + 1
end_token = len(self.charset) end_token = len(self.charset)
...@@ -125,7 +128,7 @@ class DAN: ...@@ -125,7 +128,7 @@ class DAN:
features, features,
enhanced_features, enhanced_features,
predicted_tokens, predicted_tokens,
reduced_size, input_sizes,
predicted_tokens_len, predicted_tokens_len,
features_size, features_size,
start=0, start=0,
...@@ -169,8 +172,8 @@ class DAN: ...@@ -169,8 +172,8 @@ class DAN:
predicted_text = [ predicted_text = [
LM_ind_to_str(self.charset, t, oov_symbol="") for t in predicted_tokens LM_ind_to_str(self.charset, t, oov_symbol="") for t in predicted_tokens
] ]
logging.info("Image processed") logging.info("Images processed")
if confidences: if confidences:
return predicted_text[0], confidence_scores[0] return predicted_text, confidence_scores
return predicted_text[0] return predicted_text
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment