
Remove unused decoder attributes

Merged Mélodie Boillet requested to merge remove-decoder-attrs into main
1 file  +12  −26
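Every hunk applies the same refactoring: constructor arguments and params entries that were stored on the instance but never read outside __init__ are now used directly where they are needed, so the dead attributes disappear. A minimal sketch of the pattern, using illustrative class names that are not taken from the MR:

import torch
from torch.nn import Module


class Before(Module):
    def __init__(self, dim, len_max):
        super(Before, self).__init__()
        self.len_max = len_max  # stored but never read again
        self.dim = dim          # stored but never read again
        self.pe = torch.zeros((1, dim, len_max), requires_grad=False)


class After(Module):
    def __init__(self, dim, len_max):
        super(After, self).__init__()
        # the arguments are consumed here instead of being kept on the instance
        self.pe = torch.zeros((1, dim, len_max), requires_grad=False)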
@@ -18,8 +18,6 @@ from torch.nn.init import xavier_uniform_
 class PositionalEncoding1D(Module):
     def __init__(self, dim, len_max, device):
         super(PositionalEncoding1D, self).__init__()
-        self.len_max = len_max
-        self.dim = dim
         self.pe = torch.zeros((1, dim, len_max), device=device, requires_grad=False)
         div = torch.exp(
@@ -46,9 +44,6 @@ class PositionalEncoding1D(Module):
 class PositionalEncoding2D(Module):
     def __init__(self, dim, h_max, w_max, device):
         super(PositionalEncoding2D, self).__init__()
-        self.h_max = h_max
-        self.max_w = w_max
-        self.dim = dim
         self.pe = torch.zeros(
             (1, dim, h_max, w_max), device=device, requires_grad=False
         )
@@ -177,31 +172,28 @@ class GlobalDecoderLayer(Module):
     def __init__(self, params):
         super(GlobalDecoderLayer, self).__init__()
-        self.emb_dim = params["enc_dim"]
-        self.dim_feedforward = params["dec_dim_feedforward"]
         self.self_att = CustomMultiHeadAttention(
-            embed_dim=self.emb_dim,
+            embed_dim=params["enc_dim"],
             num_heads=params["dec_num_heads"],
             proj_value=True,
             dropout=params["dec_att_dropout"],
         )
-        self.norm1 = LayerNorm(self.emb_dim)
+        self.norm1 = LayerNorm(params["enc_dim"])
         self.att = CustomMultiHeadAttention(
-            embed_dim=self.emb_dim,
+            embed_dim=params["enc_dim"],
             num_heads=params["dec_num_heads"],
             proj_value=True,
             dropout=params["dec_att_dropout"],
         )
-        self.linear1 = Linear(self.emb_dim, self.dim_feedforward)
-        self.linear2 = Linear(self.dim_feedforward, self.emb_dim)
+        self.linear1 = Linear(params["enc_dim"], params["dec_dim_feedforward"])
+        self.linear2 = Linear(params["dec_dim_feedforward"], params["enc_dim"])
         self.dropout = Dropout(params["dec_res_dropout"])
-        self.norm2 = LayerNorm(self.emb_dim)
-        self.norm3 = LayerNorm(self.emb_dim)
+        self.norm2 = LayerNorm(params["enc_dim"])
+        self.norm3 = LayerNorm(params["enc_dim"])

     def forward(
         self,
@@ -319,11 +311,8 @@ class FeaturesUpdater(Module):
     def __init__(self, params):
         super(FeaturesUpdater, self).__init__()
-        self.enc_dim = params["enc_dim"]
-        self.enc_h_max = params["h_max"]
-        self.enc_w_max = params["w_max"]
         self.pe_2d = PositionalEncoding2D(
-            self.enc_dim, self.enc_h_max, self.enc_w_max, params["device"]
+            params["enc_dim"], params["h_max"], params["w_max"], params["device"]
         )
         self.use_2d_positional_encoding = (
             "use_2d_pe" not in params or params["use_2d_pe"]
@@ -342,9 +331,6 @@ class GlobalHTADecoder(Module):
     def __init__(self, params):
         super(GlobalHTADecoder, self).__init__()
-        self.enc_dim = params["enc_dim"]
-        self.dec_l_max = params["l_max"]
         self.dropout = Dropout(params["dec_pred_dropout"])
         self.dec_att_win = (
             params["attention_win"] if params["attention_win"] is not None else 1
@@ -356,17 +342,17 @@ class GlobalHTADecoder(Module):
         self.att_decoder = GlobalAttDecoder(params)
         self.emb = Embedding(
-            num_embeddings=params["vocab_size"] + 3, embedding_dim=self.enc_dim
+            num_embeddings=params["vocab_size"] + 3, embedding_dim=params["enc_dim"]
         )
         self.pe_1d = PositionalEncoding1D(
-            self.enc_dim, self.dec_l_max, params["device"]
+            params["enc_dim"], params["l_max"], params["device"]
         )
         if self.use_lstm:
-            self.lstm_predict = LSTM(self.enc_dim, self.enc_dim)
+            self.lstm_predict = LSTM(params["enc_dim"], params["enc_dim"])
         vocab_size = params["vocab_size"] + 1
-        self.end_conv = Conv1d(self.enc_dim, vocab_size, kernel_size=1)
+        self.end_conv = Conv1d(params["enc_dim"], vocab_size, kernel_size=1)

     def forward(
         self,
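From the caller's side nothing changes: the decoder modules are still built from the same params dictionary, only the unused instance attributes are gone. A hedged construction sketch, assuming GlobalDecoderLayer is already imported from this module and using illustrative values for the keys read in the diff:

params = {
    "enc_dim": 256,               # illustrative value
    "dec_dim_feedforward": 1024,  # illustrative value
    "dec_num_heads": 4,           # illustrative value
    "dec_att_dropout": 0.1,       # illustrative value
    "dec_res_dropout": 0.1,       # illustrative value
}
layer = GlobalDecoderLayer(params)
# The layer no longer exposes emb_dim or dim_feedforward attributes; only the
# sub-modules built from params (self_att, att, linear1, linear2, norms) remain.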