
Padding tokens are not ignored during training

During decoding, a mask is generated with the goal of ignoring padding tokens.

The trouble lies in the code that generates the token mask (GlobalHTADecoder.generate_token_mask()), which reads as follows (https://gitlab.teklia.com/atr/dan/-/blob/3fe5902d37402bb9ece7ce99c6e5d2933fe38a2e/dan/ocr/decoder.py#L428):

    def generate_token_mask(self, token_len, total_size, device):
        """
        Generate mask for tokens per sample
        """
        batch_size, len_max = total_size
        mask = torch.zeros((batch_size, len_max), dtype=torch.bool, device=device)
        for i, len_ in enumerate(token_len):
            mask[i, :len_] = False
        return mask

With this code, the mask will always be all False: mask is initialized to False, and the loop only sets the same kind of positions to False again. I believe padding tokens (possibly inserted at the end of the sequence) are meant to be ignored, which requires the corresponding elements to be set to True in the mask. This would result in the following change:

    mask[i, len_:] = True
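
For reference, a minimal sketch of the function with this change applied (the comment is added for readability; otherwise the body follows the current code):

    def generate_token_mask(self, token_len, total_size, device):
        """
        Generate mask for tokens per sample (True marks padding positions to ignore)
        """
        batch_size, len_max = total_size
        mask = torch.zeros((batch_size, len_max), dtype=torch.bool, device=device)
        for i, len_ in enumerate(token_len):
            # positions beyond the actual sequence length are padding
            mask[i, len_:] = True
        return mask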

However, applying the previous change seems to cause the loss to become NaN: when applying the softmax on the heads' predictions (https://gitlab.teklia.com/atr/dan/-/blob/3fe5902d37402bb9ece7ce99c6e5d2933fe38a2e/dan/ocr/decoder.py#L143), some values become NaN. This is due to the fact that all attention weights for padding tokens in attn_output_weights become -inf after the masked_fill operation in CustomMultiHeadAttention.forward() (https://gitlab.teklia.com/atr/dan/-/blob/3fe5902d37402bb9ece7ce99c6e5d2933fe38a2e/dan/ocr/decoder.py#L135).

I believe filtering NaN values after the softmax could work:

    attn_output_weights_raw = torch.nan_to_num(attn_output_weights_raw, nan=0.0)
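
To illustrate the mechanism with a standalone snippet (not DAN code): when an entire row of attention scores ends up at -inf, the softmax over that row divides zero by zero and yields NaN, and torch.nan_to_num replaces those NaN weights with 0.0:

    import torch
    import torch.nn.functional as F

    # one attention row in which every position has been masked_fill-ed with -inf
    scores = torch.full((1, 4), float("-inf"))
    weights = F.softmax(scores, dim=-1)
    print(weights)  # tensor([[nan, nan, nan, nan]])

    # replacing NaN with 0.0 turns it into a row of zero attention weights
    print(torch.nan_to_num(weights, nan=0.0))  # tensor([[0., 0., 0., 0.]])

A fully masked position then simply contributes a zero attention row, which seems acceptable as long as those positions are also ignored by the loss.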

This indeed seems to fix the NaN loss, but I would like careful feedback on this before pushing any code, especially regarding the "orientation" of the attn_mask tensor with respect to the input token order.
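
For comparison, the usual PyTorch convention (e.g. key_padding_mask in torch.nn.MultiheadAttention) is that a boolean mask of shape (batch, source_len) with True at padded positions is broadcast over the heads and the query dimension, i.e. it masks attention towards padded keys along the last dimension of the weights. A minimal sketch of that orientation (shapes and names are hypothetical, not taken from CustomMultiHeadAttention):

    import torch

    batch_size, num_heads, target_len, source_len = 2, 1, 5, 5
    attn_scores = torch.randn(batch_size, num_heads, target_len, source_len)

    # hypothetical padding mask: True marks padded key positions
    pad_mask = torch.zeros(batch_size, source_len, dtype=torch.bool)
    pad_mask[0, 3:] = True  # sample 0 has 3 real tokens and 2 padding tokens

    # broadcast over heads (dim 1) and query positions (dim 2):
    # only the key/source dimension (last one) is masked
    masked = attn_scores.masked_fill(pad_mask[:, None, None, :], float("-inf"))
    weights = torch.softmax(masked, dim=-1)  # no NaN: every row keeps real keys

With that orientation, a padded query row still attends to the real tokens, so the softmax stays finite; a row becomes fully -inf (and thus NaN) only if the mask is instead applied along the query dimension, or combined with another mask that removes the remaining positions. Checking which dimension the token mask is broadcast over in CustomMultiHeadAttention should answer the orientation question.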
