From 80eb95e43ed2ee943fd43e412e8ce1b47b59d15b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sol=C3=A8ne=20Tarride?= <starride@teklia.com> Date: Tue, 12 Sep 2023 11:26:01 +0200 Subject: [PATCH] Add CTC frame between each frames --- dan/ocr/decoder.py | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/dan/ocr/decoder.py b/dan/ocr/decoder.py index 8f80341a..6a158382 100644 --- a/dan/ocr/decoder.py +++ b/dan/ocr/decoder.py @@ -516,17 +516,19 @@ class CTCLanguageDecoder: # column with 1 probability on CTC token ctc_probs = ( - torch.ones((batch_size, 1, n_tokens), dtype=torch.float32) * 0.1 / n_tokens + torch.ones((batch_size, 1, n_tokens), dtype=torch.float32) + * 0.01 + / (n_tokens - 1) ) - ctc_probs[:, :, self.ctc_id] = 0.9 + ctc_probs[:, :, self.ctc_id] = 0.99 ctc_probs = ctc_probs.log() for i in range(n_frames - 1): batch_features = torch.cat( [ - batch_features[:, 2 * i + 1 :, :], - ctc_probs, batch_features[:, : 2 * i + 1, :], + ctc_probs, + batch_features[:, 2 * i + 1 :, :], ], dim=1, ) @@ -534,14 +536,15 @@ class CTCLanguageDecoder: def post_process(self, hypotheses): """ - Post-process hypotheses to output JSON + Post-process hypotheses to output JSON. Exports only the best hypothesis for each image. """ out = {} - # Export only the best hypothesis + # Replace <space> by an actual space and format string out["text"] = [ "".join(hypothesis[0].words).replace(self.space_token, " ") for hypothesis in hypotheses ] + # Normalize confidence score out["confidence"] = [ np.exp(hypothesis[0].score / hypothesis[0].timesteps[-1].item()) for hypothesis in hypotheses @@ -552,13 +555,12 @@ class CTCLanguageDecoder: """ Decode a feature vector using n-gram language modelling. Args: - features (Any): feature vector of size (n_frame, batch_size, n_tokens). - Can be either a torch.tensor or a torch.nn.utils.rnn.PackedSequence + features (torch.tensor): feature vector of size (batch_size, n_tokens, n_frame). + batch_sizes (Union[List, torch.tensor]): actual length of predictions Returns: - out (Dict[str, List]): a dictionary containing the hypothesis (the list of decoded tokens). - There is no character-based probability. + out (Dict[List]): a dictionary containing the hypotheses. """ - # Reshape from (n_frame, batch_size, n_tokens) to (batch_size, n_frame, n_tokens) + # Reshape from (batch_size, n_tokens, n_frames) to (batch_size, n_frames, n_tokens) batch_features = batch_features.permute((0, 2, 1)) # Apply temperature scaling @@ -566,8 +568,8 @@ class CTCLanguageDecoder: # Apply log softmax batch_features = torch.nn.functional.log_softmax(batch_features, dim=-1) - # batch_features = self.add_ctc_frames(batch_features) - # batch_sizes = batch_features.shape[0] + batch_features = self.add_ctc_frames(batch_features) + batch_sizes *= 2 # No GPU support for torchaudio's ctc_decoder device = torch.device("cpu") -- GitLab