Padding tokens are not ignored during training
During decoding, a mask is generated with the goal of ignoring padding tokens:

- it is initialized in `GlobalHTADecoder.forward()`: https://gitlab.teklia.com/atr/dan/-/blob/3fe5902d37402bb9ece7ce99c6e5d2933fe38a2e/dan/ocr/decoder.py#L385
- it is then passed to `GlobalAttDecoder.forward()`: https://gitlab.teklia.com/atr/dan/-/blob/3fe5902d37402bb9ece7ce99c6e5d2933fe38a2e/dan/ocr/decoder.py#L395
- which in turn calls `GlobalDecoderLayer.forward()`: https://gitlab.teklia.com/atr/dan/-/blob/3fe5902d37402bb9ece7ce99c6e5d2933fe38a2e/dan/ocr/decoder.py#L271
- which calls `CustomMultiHeadAttention.forward()`: https://gitlab.teklia.com/atr/dan/-/blob/3fe5902d37402bb9ece7ce99c6e5d2933fe38a2e/dan/ocr/decoder.py#L211 (the other call uses the memory mask, which is initialized differently: https://gitlab.teklia.com/atr/dan/-/blob/3fe5902d37402bb9ece7ce99c6e5d2933fe38a2e/dan/ocr/decoder.py#L223)
- which ultimately uses this mask to set the values of `attn_output_weights` to `float('-inf')` for mask items which are `True`: https://gitlab.teklia.com/atr/dan/-/blob/3fe5902d37402bb9ece7ce99c6e5d2933fe38a2e/dan/ocr/decoder.py#L135
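For context, this is the usual key-padding-mask pattern: entries marked `True` are replaced with `float('-inf')` before the softmax so the corresponding positions receive zero attention weight. A minimal, self-contained sketch of that mechanism (not the project's actual code; shapes and names are illustrative):

```python
import torch

# Raw attention scores: (batch, num_queries, num_keys)
attn_output_weights = torch.randn(2, 3, 4)

# True where a key position is padding and must be ignored
key_padding_mask = torch.tensor(
    [[False, False, False, True],   # sample 0: last key is padding
     [False, False, True, True]]    # sample 1: last two keys are padding
)

# Broadcast the mask over the query dimension and suppress padded keys
masked = attn_output_weights.masked_fill(
    key_padding_mask.unsqueeze(1), float("-inf")
)

# After softmax, padded keys get exactly zero attention weight
weights = torch.softmax(masked, dim=-1)
print(weights[0, 0])  # last entry is 0.0
```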
The trouble lies in the code that generates the token mask, `GlobalHTADecoder.generate_token_mask()`, which reads as follows (https://gitlab.teklia.com/atr/dan/-/blob/3fe5902d37402bb9ece7ce99c6e5d2933fe38a2e/dan/ocr/decoder.py#L428):
```python
def generate_token_mask(self, token_len, total_size, device):
    """
    Generate mask for tokens per sample
    """
    batch_size, len_max = total_size
    mask = torch.zeros((batch_size, len_max), dtype=torch.bool, device=device)
    for i, len_ in enumerate(token_len):
        mask[i, :len_] = False
    return mask
```
With this code, the mask is always all `False`: it is initialized to `False`, and the loop only re-assigns `False` to the first `len_` positions. I believe padding tokens (which may be inserted at the end of a sequence to reach `len_max`) are meant to be ignored, which requires the corresponding elements of the mask to be set to `True`. This would result in the following change:

```python
mask[i, len_:] = True
```
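As a sanity check, here is a small, self-contained reproduction of both behaviours (written as free functions with illustrative names rather than the actual method):

```python
import torch


def generate_token_mask_current(token_len, total_size, device):
    """Current implementation: the loop writes False over an all-False tensor."""
    batch_size, len_max = total_size
    mask = torch.zeros((batch_size, len_max), dtype=torch.bool, device=device)
    for i, len_ in enumerate(token_len):
        mask[i, :len_] = False
    return mask


def generate_token_mask_fixed(token_len, total_size, device):
    """Proposed fix: mark positions *after* each sequence's length as padding."""
    batch_size, len_max = total_size
    mask = torch.zeros((batch_size, len_max), dtype=torch.bool, device=device)
    for i, len_ in enumerate(token_len):
        mask[i, len_:] = True
    return mask


token_len = [3, 5]
total_size = (2, 5)

print(generate_token_mask_current(token_len, total_size, "cpu"))
# tensor([[False, False, False, False, False],
#         [False, False, False, False, False]])

print(generate_token_mask_fixed(token_len, total_size, "cpu"))
# tensor([[False, False, False,  True,  True],
#         [False, False, False, False, False]])
```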
However, applying this change seems to cause the loss to become `nan`: when applying the softmax to the heads' predictions (https://gitlab.teklia.com/atr/dan/-/blob/3fe5902d37402bb9ece7ce99c6e5d2933fe38a2e/dan/ocr/decoder.py#L143), some values become `nan`. This is because all of the weights associated with padding tokens in `attn_output_weights` become `-inf` after the `masked_fill` operation in `CustomMultiHeadAttention.forward()` (https://gitlab.teklia.com/atr/dan/-/blob/3fe5902d37402bb9ece7ce99c6e5d2933fe38a2e/dan/ocr/decoder.py#L135), and the softmax of a row that contains only `-inf` values is undefined.
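The `nan` can be reproduced in isolation: when every entry of a row is `-inf`, the softmax has nothing to normalize.

```python
import torch

scores = torch.tensor([
    [0.5, 1.2, -0.3],                                # regular row
    [float("-inf"), float("-inf"), float("-inf")],   # fully masked row
])

print(torch.softmax(scores, dim=-1))
# first row: a valid probability distribution summing to 1
# second row: tensor([nan, nan, nan]) -- exp(-inf) is 0 everywhere, so softmax computes 0/0
```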
I believe filtering nan
items after the softmax could work:
attn_output_weights_raw = torch.nan_to_num(attn_output_weights_raw, nan=0.0)
This indeed seems to fix the problem of nan
loss, but I'd like a careful feedback on this before pushing any code, especially regarding the "orientation" of the attn_mask
tensor regarding input tokens order.
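To make the effect of the proposed workaround concrete, here is the behaviour in isolation (a sketch, not the actual decoder code):

```python
import torch

# One query row whose keys are all masked out (all -inf after masked_fill)
scores = torch.full((1, 4), float("-inf"))

weights = torch.softmax(scores, dim=-1)       # tensor([[nan, nan, nan, nan]])
weights = torch.nan_to_num(weights, nan=0.0)  # tensor([[0., 0., 0., 0.]])

print(weights)
```

Since the attention output is the product of the weights and the values, a zeroed weight row yields a zero output vector for that position instead of `nan`, which would explain why the loss stops diverging.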