diff --git a/dan/ocr/decoder.py b/dan/ocr/decoder.py
index b9b256d1c28ca72c5d9c280f1feec88d2cf9bf26..6dfda6d9fb39f2885f1d42b1ea39661bcb6b8a3f 100644
--- a/dan/ocr/decoder.py
+++ b/dan/ocr/decoder.py
@@ -495,6 +495,7 @@ class CTCLanguageDecoder:
         self.tokens_to_index = {
             token: i for i, token in enumerate(read_txt(tokens_path).split("\n"))
         }
+        self.index_to_token = {i: token for token, i in self.tokens_to_index.items()}
         self.blank_token_id = self.tokens_to_index[self.mapping.ctc.encoded]
         self.decoder = ctc_decoder(
             lm=language_model_path,
@@ -516,7 +517,6 @@ class CTCLanguageDecoder:
         high_prob = batch_features.max()
         low_prob = batch_features.min()
         batch_size, n_frames, n_tokens = batch_features.shape
-        # Reset probabilities for the CTC token
         batch_features[:, :, -1] = (
             torch.ones((batch_size, n_frames), dtype=torch.float32)
             * low_prob
@@ -553,10 +553,10 @@ class CTCLanguageDecoder:
         out["text"] = [
             "".join(
                 [
-                    self.mapping.display[token]
-                    if token in self.mapping.display
-                    else token
-                    for token in hypothesis[0].words
+                    self.mapping.display[self.index_to_token[token]]
+                    if self.index_to_token[token] in self.mapping.display
+                    else self.index_to_token[token]
+                    for token in hypothesis[0].tokens.tolist()
                 ]
             )
             for hypothesis in hypotheses
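
For context, a minimal sketch of the decoding path this diff touches, assuming torchaudio's `ctc_decoder` API: each returned `CTCHypothesis` exposes `tokens` as a tensor of vocabulary indices, while `words` is only populated during lexicon-based decoding, so a lexicon-free setup has to map indices back to token strings itself. That is what the new `index_to_token` table enables. The token list and `display` table below are hypothetical stand-ins for the repository's own mapping, not its actual values.

```python
# Minimal sketch, not the repo's code: assumes torchaudio's ctc_decoder API.
# The token list and `display` table are hypothetical stand-ins.
import torch
from torchaudio.models.decoder import ctc_decoder

tokens = ["a", "b", "c", " ", "<ctc>"]  # index -> token string
tokens_to_index = {token: i for i, token in enumerate(tokens)}
index_to_token = {i: token for token, i in tokens_to_index.items()}
display = {" ": "␣"}  # tokens that have a special display form

decoder = ctc_decoder(
    lexicon=None,          # lexicon-free: hypothesis.words stays empty
    tokens=tokens,
    blank_token="<ctc>",
    sil_token=" ",
)

# Fake emissions: (batch, frames, n_tokens) log-probabilities on CPU.
emissions = torch.randn(1, 10, len(tokens)).log_softmax(dim=-1)
hypotheses = decoder(emissions)

# hypothesis.tokens holds vocabulary indices, so they must be mapped back
# to strings before applying the display table, mirroring this diff.
text = "".join(
    display.get(index_to_token[i], index_to_token[i])
    for i in hypotheses[0][0].tokens.tolist()
)
print(text)
```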