diff --git a/dan/predict/prediction.py b/dan/predict/prediction.py index d213e8bf8f358b98264e1046fc25c0ffc0f04c61..55bf817c77d188951c15c69e38cd41fa245aee79 100644 --- a/dan/predict/prediction.py +++ b/dan/predict/prediction.py @@ -92,6 +92,7 @@ class DAN: input_sizes, confidences=False, attentions=False, + start_token=None, ): """ Run prediction on an input image. @@ -102,7 +103,9 @@ class DAN: """ input_tensor = input_tensor.to(self.device) - start_token = len(self.charset) + 1 + start_token = ( + self.charset.index(start_token) if start_token else len(self.charset) + 1 + ) end_token = len(self.charset) # Run the prediction.