diff --git a/dan/ocr/decoder.py b/dan/ocr/decoder.py index 6a1583824ed05b4664ba70b4504ae54c1ff7e722..59909d146a59efdefa99c8d6b8c095d437f9e309 100644 --- a/dan/ocr/decoder.py +++ b/dan/ocr/decoder.py @@ -503,10 +503,9 @@ class CTCLanguageDecoder: nbest=1, ) self.temperature = temperature - - self.tokens_to_idx = read_txt(tokens_path).split("\n") - self.ctc_id = self.tokens_to_idx.index(blank_token) self.space_token = sil_token + self.tokens_to_idx = read_txt(tokens_path).split("\n") + self.blank_id = self.tokens_to_idx.index(blank_token) def add_ctc_frames(self, batch_features): """ @@ -514,16 +513,19 @@ class CTCLanguageDecoder: """ batch_size, n_frames, n_tokens = batch_features.shape - # column with 1 probability on CTC token + # Create tensor with high probability CTC token + high_prob = 0.99 + low_prob = 0.01 ctc_probs = ( torch.ones((batch_size, 1, n_tokens), dtype=torch.float32) - * 0.01 + * low_prob / (n_tokens - 1) ) - ctc_probs[:, :, self.ctc_id] = 0.99 + ctc_probs[:, :, self.blank_id] = high_prob ctc_probs = ctc_probs.log() - for i in range(n_frames - 1): + # Insert CTC tensor between frames + for i in range(n_frames): batch_features = torch.cat( [ batch_features[:, : 2 * i + 1, :],