diff --git a/dan/predict/prediction.py b/dan/predict/prediction.py index 2260c2312b54fad5388d528a81f76e62d45e715b..3cff98d701ea06eaf25b8d988775506b02ff11ed 100644 --- a/dan/predict/prediction.py +++ b/dan/predict/prediction.py @@ -218,11 +218,9 @@ def compute_prob_by_separator(characters, probabilities, separator): # match anything except separators, get start and end index pattern = re.compile(f"[^{separator.pattern}]+") matches = [(m.start(), m.end()) for m in re.finditer(pattern, characters)] - return [ - np.mean(probabilities[start:end]) - for (start, end) in matches - ] + # Iterate over text pieces and compute mean confidence + return [np.mean(probabilities[start:end]) for (start, end) in matches] def run(