Verified commit be01536c, authored by Mélodie Boillet

Remove unused decoder attributes

parent 06f97ecc
1 merge request: !190 Remove unused decoder attributes
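For context, the deleted attributes only mirrored values that the modules already receive through their `params` dictionary; after this change the constructors index into `params` directly. Below is a minimal sketch of such a configuration, using only keys that are visible in this diff; the values are illustrative placeholders, not the project's defaults.

import torch

# Illustrative values only -- the real defaults live in the project's training
# configuration, which is not part of this diff.
params = {
    "enc_dim": 256,
    "dec_dim_feedforward": 256,
    "dec_num_heads": 4,
    "dec_att_dropout": 0.1,
    "dec_res_dropout": 0.1,
    "dec_pred_dropout": 0.1,
    "vocab_size": 100,
    "h_max": 500,
    "w_max": 1000,
    "l_max": 15000,
    "use_2d_pe": True,
    "attention_win": 100,
    "device": torch.device("cpu"),
}

# After this commit, e.g. GlobalHTADecoder(params) builds its embedding with
# params["enc_dim"] instead of first copying it to self.enc_dim.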
@@ -18,8 +18,6 @@ from torch.nn.init import xavier_uniform_
 class PositionalEncoding1D(Module):
     def __init__(self, dim, len_max, device):
         super(PositionalEncoding1D, self).__init__()
-        self.len_max = len_max
-        self.dim = dim
         self.pe = torch.zeros((1, dim, len_max), device=device, requires_grad=False)
         div = torch.exp(
@@ -46,9 +44,6 @@ class PositionalEncoding1D(Module):
 class PositionalEncoding2D(Module):
     def __init__(self, dim, h_max, w_max, device):
         super(PositionalEncoding2D, self).__init__()
-        self.h_max = h_max
-        self.max_w = w_max
-        self.dim = dim
         self.pe = torch.zeros(
             (1, dim, h_max, w_max), device=device, requires_grad=False
         )
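The two constructors above precompute a fixed table `self.pe` from their arguments, so keeping `dim`, `len_max`, `h_max` or `w_max` as attributes served no purpose. As a point of reference, here is a minimal standalone sketch of a 1D positional encoding built the same way; it assumes the standard sinusoidal transformer formula and is not the repository's exact implementation.

import math

import torch
from torch.nn import Module


class TinyPositionalEncoding1D(Module):
    # Minimal sketch: `dim` and `len_max` are only needed to build the table,
    # so they are not stored as attributes. `dim` is assumed to be even.
    def __init__(self, dim, len_max, device):
        super().__init__()
        self.pe = torch.zeros((1, dim, len_max), device=device, requires_grad=False)
        pos = torch.arange(len_max, device=device, dtype=torch.float32)
        div = torch.exp(
            -math.log(10000.0)
            * torch.arange(0, dim, 2, device=device, dtype=torch.float32)
            / dim
        )
        self.pe[:, 0::2, :] = torch.sin(div.unsqueeze(1) * pos.unsqueeze(0))
        self.pe[:, 1::2, :] = torch.cos(div.unsqueeze(1) * pos.unsqueeze(0))

    def forward(self, x, start=0):
        # x: (batch, dim, length); add the matching slice of the table.
        return x + self.pe[:, :, start : start + x.size(2)]


x = torch.zeros((2, 8, 16))
print(TinyPositionalEncoding1D(8, 32, torch.device("cpu"))(x).shape)  # (2, 8, 16)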
@@ -177,31 +172,28 @@ class GlobalDecoderLayer(Module):
     def __init__(self, params):
         super(GlobalDecoderLayer, self).__init__()
-        self.emb_dim = params["enc_dim"]
-        self.dim_feedforward = params["dec_dim_feedforward"]
         self.self_att = CustomMultiHeadAttention(
-            embed_dim=self.emb_dim,
+            embed_dim=params["enc_dim"],
             num_heads=params["dec_num_heads"],
             proj_value=True,
             dropout=params["dec_att_dropout"],
         )
-        self.norm1 = LayerNorm(self.emb_dim)
+        self.norm1 = LayerNorm(params["enc_dim"])
         self.att = CustomMultiHeadAttention(
-            embed_dim=self.emb_dim,
+            embed_dim=params["enc_dim"],
             num_heads=params["dec_num_heads"],
             proj_value=True,
             dropout=params["dec_att_dropout"],
         )
-        self.linear1 = Linear(self.emb_dim, self.dim_feedforward)
-        self.linear2 = Linear(self.dim_feedforward, self.emb_dim)
+        self.linear1 = Linear(params["enc_dim"], params["dec_dim_feedforward"])
+        self.linear2 = Linear(params["dec_dim_feedforward"], params["enc_dim"])
         self.dropout = Dropout(params["dec_res_dropout"])
-        self.norm2 = LayerNorm(self.emb_dim)
-        self.norm3 = LayerNorm(self.emb_dim)

     def forward(
         self,
@@ -319,11 +311,8 @@ class FeaturesUpdater(Module):
     def __init__(self, params):
         super(FeaturesUpdater, self).__init__()
-        self.enc_dim = params["enc_dim"]
-        self.enc_h_max = params["h_max"]
-        self.enc_w_max = params["w_max"]
         self.pe_2d = PositionalEncoding2D(
-            self.enc_dim, self.enc_h_max, self.enc_w_max, params["device"]
+            params["enc_dim"], params["h_max"], params["w_max"], params["device"]
         )
         self.use_2d_positional_encoding = (
             "use_2d_pe" not in params or params["use_2d_pe"]
@@ -342,9 +331,6 @@ class GlobalHTADecoder(Module):
     def __init__(self, params):
         super(GlobalHTADecoder, self).__init__()
-        self.enc_dim = params["enc_dim"]
-        self.dec_l_max = params["l_max"]
         self.dropout = Dropout(params["dec_pred_dropout"])
         self.dec_att_win = (
             params["attention_win"] if params["attention_win"] is not None else 1
@@ -356,17 +342,17 @@ class GlobalHTADecoder(Module):
         self.att_decoder = GlobalAttDecoder(params)
         self.emb = Embedding(
-            num_embeddings=params["vocab_size"] + 3, embedding_dim=self.enc_dim
+            num_embeddings=params["vocab_size"] + 3, embedding_dim=params["enc_dim"]
         )
         self.pe_1d = PositionalEncoding1D(
-            self.enc_dim, self.dec_l_max, params["device"]
+            params["enc_dim"], params["l_max"], params["device"]
         )
         if self.use_lstm:
-            self.lstm_predict = LSTM(self.enc_dim, self.enc_dim)
+            self.lstm_predict = LSTM(params["enc_dim"], params["enc_dim"])
         vocab_size = params["vocab_size"] + 1
-        self.end_conv = Conv1d(self.enc_dim, vocab_size, kernel_size=1)
+        self.end_conv = Conv1d(params["enc_dim"], vocab_size, kernel_size=1)

     def forward(
         self,
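The last hunk shows how the decoder maps tokens to features and features back to the vocabulary. Here is a short, self-contained sketch of those two layers with illustrative sizes; the `+ 3` and `+ 1` offsets come from the diff, while their interpretation (extra special tokens and an extra output class) is an assumption.

import torch
from torch.nn import Conv1d, Embedding

enc_dim, vocab_size = 256, 100  # illustrative sizes, not the project's defaults

emb = Embedding(num_embeddings=vocab_size + 3, embedding_dim=enc_dim)
end_conv = Conv1d(enc_dim, vocab_size + 1, kernel_size=1)

tokens = torch.randint(0, vocab_size + 3, (2, 20))   # (batch, sequence)
features = emb(tokens).permute(0, 2, 1)              # (batch, enc_dim, sequence)
logits = end_conv(features)                          # (batch, vocab_size + 1, sequence)
print(logits.shape)                                  # torch.Size([2, 101, 20])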