diff --git a/dan/decoder.py b/dan/decoder.py
index 3cb86963dcfbbaa930def033b5499aba2f31cb46..84518ab5e0c322ed164fe8b4f0fd02d26b74c0d8 100644
--- a/dan/decoder.py
+++ b/dan/decoder.py
@@ -18,8 +18,6 @@ from torch.nn.init import xavier_uniform_
 class PositionalEncoding1D(Module):
     def __init__(self, dim, len_max, device):
         super(PositionalEncoding1D, self).__init__()
-        self.len_max = len_max
-        self.dim = dim
         self.pe = torch.zeros((1, dim, len_max), device=device, requires_grad=False)
 
         div = torch.exp(
@@ -46,9 +44,6 @@ class PositionalEncoding1D(Module):
 class PositionalEncoding2D(Module):
     def __init__(self, dim, h_max, w_max, device):
         super(PositionalEncoding2D, self).__init__()
-        self.h_max = h_max
-        self.max_w = w_max
-        self.dim = dim
         self.pe = torch.zeros(
             (1, dim, h_max, w_max), device=device, requires_grad=False
         )
@@ -177,31 +172,28 @@ class GlobalDecoderLayer(Module):
 
     def __init__(self, params):
         super(GlobalDecoderLayer, self).__init__()
-        self.emb_dim = params["enc_dim"]
-        self.dim_feedforward = params["dec_dim_feedforward"]
-
         self.self_att = CustomMultiHeadAttention(
-            embed_dim=self.emb_dim,
+            embed_dim=params["enc_dim"],
             num_heads=params["dec_num_heads"],
             proj_value=True,
             dropout=params["dec_att_dropout"],
         )
-        self.norm1 = LayerNorm(self.emb_dim)
+        self.norm1 = LayerNorm(params["enc_dim"])
 
         self.att = CustomMultiHeadAttention(
-            embed_dim=self.emb_dim,
+            embed_dim=params["enc_dim"],
             num_heads=params["dec_num_heads"],
             proj_value=True,
             dropout=params["dec_att_dropout"],
         )
 
-        self.linear1 = Linear(self.emb_dim, self.dim_feedforward)
-        self.linear2 = Linear(self.dim_feedforward, self.emb_dim)
+        self.linear1 = Linear(params["enc_dim"], params["dec_dim_feedforward"])
+        self.linear2 = Linear(params["dec_dim_feedforward"], params["enc_dim"])
 
         self.dropout = Dropout(params["dec_res_dropout"])
 
-        self.norm2 = LayerNorm(self.emb_dim)
-        self.norm3 = LayerNorm(self.emb_dim)
+        self.norm2 = LayerNorm(params["enc_dim"])
+        self.norm3 = LayerNorm(params["enc_dim"])
 
     def forward(
         self,
@@ -319,11 +311,8 @@ class FeaturesUpdater(Module):
 
     def __init__(self, params):
         super(FeaturesUpdater, self).__init__()
-        self.enc_dim = params["enc_dim"]
-        self.enc_h_max = params["h_max"]
-        self.enc_w_max = params["w_max"]
         self.pe_2d = PositionalEncoding2D(
-            self.enc_dim, self.enc_h_max, self.enc_w_max, params["device"]
+            params["enc_dim"], params["h_max"], params["w_max"], params["device"]
         )
         self.use_2d_positional_encoding = (
             "use_2d_pe" not in params or params["use_2d_pe"]
@@ -342,9 +331,6 @@ class GlobalHTADecoder(Module):
 
     def __init__(self, params):
         super(GlobalHTADecoder, self).__init__()
-        self.enc_dim = params["enc_dim"]
-        self.dec_l_max = params["l_max"]
-
         self.dropout = Dropout(params["dec_pred_dropout"])
         self.dec_att_win = (
             params["attention_win"] if params["attention_win"] is not None else 1
@@ -356,17 +342,17 @@ class GlobalHTADecoder(Module):
         self.att_decoder = GlobalAttDecoder(params)
 
         self.emb = Embedding(
-            num_embeddings=params["vocab_size"] + 3, embedding_dim=self.enc_dim
+            num_embeddings=params["vocab_size"] + 3, embedding_dim=params["enc_dim"]
         )
         self.pe_1d = PositionalEncoding1D(
-            self.enc_dim, self.dec_l_max, params["device"]
+            params["enc_dim"], params["l_max"], params["device"]
         )
 
         if self.use_lstm:
-            self.lstm_predict = LSTM(self.enc_dim, self.enc_dim)
+            self.lstm_predict = LSTM(params["enc_dim"], params["enc_dim"])
 
         vocab_size = params["vocab_size"] + 1
-        self.end_conv = Conv1d(self.enc_dim, vocab_size, kernel_size=1)
+        self.end_conv = Conv1d(params["enc_dim"], vocab_size, kernel_size=1)
 
     def forward(
         self,