Skip to content
Snippets Groups Projects
Verified Commit 4cbca724 authored by Mélodie Boillet's avatar Mélodie Boillet
Browse files

Add start_token parameter to prediction function

parent d8684909
No related branches found
No related tags found
1 merge request!94Add start_token parameter to prediction function
This commit is part of merge request !94. Comments created here will be created in the context of that merge request.
......@@ -6,9 +6,9 @@ import re
import cv2
import numpy as np
import torch
import yaml
from dan import logger
from dan.datasets.extract.utils import save_json
from dan.decoder import GlobalHTADecoder
......@@ -92,6 +92,7 @@ class DAN:
input_sizes,
confidences=False,
attentions=False,
start_token=None,
):
"""
Run prediction on an input image.
......@@ -102,7 +103,9 @@ class DAN:
"""
input_tensor = input_tensor.to(self.device)
start_token = len(self.charset) + 1
start_token = (
self.charset.index(start_token) if start_token else len(self.charset) + 1
)
end_token = len(self.charset)
# Run the prediction.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment