From f5e608cd58a49df59729698d5503e7c859335c66 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?M=C3=A9lodie=20Boillet?= <boillet@teklia.com> Date: Mon, 3 Apr 2023 13:02:40 +0000 Subject: [PATCH] Add start_token parameter to prediction function --- dan/predict/prediction.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/dan/predict/prediction.py b/dan/predict/prediction.py index d213e8bf..55bf817c 100644 --- a/dan/predict/prediction.py +++ b/dan/predict/prediction.py @@ -92,6 +92,7 @@ class DAN: input_sizes, confidences=False, attentions=False, + start_token=None, ): """ Run prediction on an input image. @@ -102,7 +103,9 @@ class DAN: """ input_tensor = input_tensor.to(self.device) - start_token = len(self.charset) + 1 + start_token = ( + self.charset.index(start_token) if start_token else len(self.charset) + 1 + ) end_token = len(self.charset) # Run the prediction. -- GitLab