diff --git a/dan/predict/prediction.py b/dan/predict/prediction.py
index d213e8bf8f358b98264e1046fc25c0ffc0f04c61..55bf817c77d188951c15c69e38cd41fa245aee79 100644
--- a/dan/predict/prediction.py
+++ b/dan/predict/prediction.py
@@ -92,6 +92,7 @@ class DAN:
         input_sizes,
         confidences=False,
         attentions=False,
+        start_token=None,
     ):
         """
         Run prediction on an input image.
@@ -102,7 +103,9 @@ class DAN:
         """
         input_tensor = input_tensor.to(self.device)
 
-        start_token = len(self.charset) + 1
+        start_token = (
+            self.charset.index(start_token) if start_token else len(self.charset) + 1
+        )
         end_token = len(self.charset)
 
         # Run the prediction.