Skip to content
Snippets Groups Projects

Add batch prediction code

Merged Mélodie Boillet requested to merge batch-prediction into main
1 file
+ 17
14
Compare changes
  • Side-by-side
  • Inline
+ 17
14
@@ -66,13 +66,11 @@ class DAN:
self.mean, self.std = parameters["mean"], parameters["std"]
self.max_chars = parameters["max_char_prediction"]
def predict(self, input_image, confidences=False):
def preprocess(self, input_image):
"""
Run prediction on an input image.
:param input_image: The image to predict.
:param confidences: Return the characters probabilities.
Preprocess an input image.
:param input_image: The input image to preprocess.
"""
# Preprocess image.
assert isinstance(
input_image, np.ndarray
), "Input image must be an np.array in RGB"
@@ -80,12 +78,17 @@ class DAN:
if len(input_image.shape) < 3:
input_image = cv2.cvtColor(input_image, cv2.COLOR_GRAY2RGB)
reduced_size = [input_image.shape[:2]]
input_image = (input_image - self.mean) / self.std
input_image = np.expand_dims(input_image.transpose((2, 0, 1)), axis=0)
input_tensor = torch.from_numpy(input_image).to(self.device)
logging.debug("Image pre-processed")
return input_image
def predict(self, input_tensor, input_sizes, confidences=False):
"""
Run prediction on a batch of input images.
:param input_tensor: A batch of images to predict.
:param input_sizes: The original image sizes.
:param confidences: Return the character probabilities.
"""
input_tensor.to(self.device)
start_token = len(self.charset) + 1
end_token = len(self.charset)
@@ -125,7 +128,7 @@ class DAN:
features,
enhanced_features,
predicted_tokens,
reduced_size,
input_sizes,
predicted_tokens_len,
features_size,
start=0,
@@ -169,8 +172,8 @@ class DAN:
predicted_text = [
LM_ind_to_str(self.charset, t, oov_symbol="") for t in predicted_tokens
]
logging.info("Image processed")
logging.info("Images processed")
if confidences:
return predicted_text[0], confidence_scores[0]
return predicted_text[0]
return predicted_text, confidence_scores
return predicted_text
Loading