From 7f25f1373480804a85ae8dd9a25b498d85ad7739 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sol=C3=A8ne=20Tarride?= <starride@teklia.com> Date: Tue, 12 Sep 2023 12:03:34 +0200 Subject: [PATCH] Improve documentation --- dan/ocr/decoder.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/dan/ocr/decoder.py b/dan/ocr/decoder.py index 6a158382..59909d14 100644 --- a/dan/ocr/decoder.py +++ b/dan/ocr/decoder.py @@ -503,10 +503,9 @@ class CTCLanguageDecoder: nbest=1, ) self.temperature = temperature - - self.tokens_to_idx = read_txt(tokens_path).split("\n") - self.ctc_id = self.tokens_to_idx.index(blank_token) self.space_token = sil_token + self.tokens_to_idx = read_txt(tokens_path).split("\n") + self.blank_id = self.tokens_to_idx.index(blank_token) def add_ctc_frames(self, batch_features): """ @@ -514,16 +513,19 @@ class CTCLanguageDecoder: """ batch_size, n_frames, n_tokens = batch_features.shape - # column with 1 probability on CTC token + # Create tensor with high probability CTC token + high_prob = 0.99 + low_prob = 0.01 ctc_probs = ( torch.ones((batch_size, 1, n_tokens), dtype=torch.float32) - * 0.01 + * low_prob / (n_tokens - 1) ) - ctc_probs[:, :, self.ctc_id] = 0.99 + ctc_probs[:, :, self.blank_id] = high_prob ctc_probs = ctc_probs.log() - for i in range(n_frames - 1): + # Insert CTC tensor between frames + for i in range(n_frames): batch_features = torch.cat( [ batch_features[:, : 2 * i + 1, :], -- GitLab