From f5e608cd58a49df59729698d5503e7c859335c66 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?M=C3=A9lodie=20Boillet?= <boillet@teklia.com>
Date: Mon, 3 Apr 2023 13:02:40 +0000
Subject: [PATCH] Add start_token  parameter to prediction function

---
 dan/predict/prediction.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/dan/predict/prediction.py b/dan/predict/prediction.py
index d213e8bf..55bf817c 100644
--- a/dan/predict/prediction.py
+++ b/dan/predict/prediction.py
@@ -92,6 +92,7 @@ class DAN:
         input_sizes,
         confidences=False,
         attentions=False,
+        start_token=None,
     ):
         """
         Run prediction on an input image.
@@ -102,7 +103,9 @@ class DAN:
         """
         input_tensor = input_tensor.to(self.device)
 
-        start_token = len(self.charset) + 1
+        start_token = (
+            self.charset.index(start_token) if start_token else len(self.charset) + 1
+        )
         end_token = len(self.charset)
 
         # Run the prediction.
-- 
GitLab