Verified Commit be01536c authored by Mélodie Boillet

Remove unused decoder attributes

parent 06f97ecc
1 merge request: !190 Remove unused decoder attributes
@@ -18,8 +18,6 @@ from torch.nn.init import xavier_uniform_
 class PositionalEncoding1D(Module):
     def __init__(self, dim, len_max, device):
         super(PositionalEncoding1D, self).__init__()
-        self.len_max = len_max
-        self.dim = dim
         self.pe = torch.zeros((1, dim, len_max), device=device, requires_grad=False)
         div = torch.exp(
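
For context: the removed self.len_max and self.dim attributes were only read inside __init__; once the self.pe table has been precomputed, the rest of the module presumably only slices it. A minimal, self-contained sketch of that pattern (illustrative only, assumes an even dim, and is not the project's exact code):

import math

import torch
from torch.nn import Module


class SketchPositionalEncoding1D(Module):  # hypothetical name, for illustration
    def __init__(self, dim, len_max, device="cpu"):
        super().__init__()
        # Precompute the sinusoidal table once; dim and len_max need not be kept.
        pe = torch.zeros((1, dim, len_max), device=device, requires_grad=False)
        pos = torch.arange(len_max, device=device).float()
        div = torch.exp(
            torch.arange(0, dim, 2, device=device).float() * (-math.log(10000.0) / dim)
        )
        pe[0, 0::2, :] = torch.sin(div[:, None] * pos[None, :])
        pe[0, 1::2, :] = torch.cos(div[:, None] * pos[None, :])
        self.pe = pe  # the only state forward() needs

    def forward(self, x):
        # x: (batch, dim, length); add the cached encoding up to the input length.
        return x + self.pe[:, :, : x.size(2)]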
@@ -46,9 +44,6 @@ class PositionalEncoding1D(Module):
 class PositionalEncoding2D(Module):
     def __init__(self, dim, h_max, w_max, device):
         super(PositionalEncoding2D, self).__init__()
-        self.h_max = h_max
-        self.max_w = w_max
-        self.dim = dim
         self.pe = torch.zeros(
             (1, dim, h_max, w_max), device=device, requires_grad=False
         )
@@ -177,31 +172,28 @@ class GlobalDecoderLayer(Module):
     def __init__(self, params):
         super(GlobalDecoderLayer, self).__init__()
-        self.emb_dim = params["enc_dim"]
-        self.dim_feedforward = params["dec_dim_feedforward"]
         self.self_att = CustomMultiHeadAttention(
-            embed_dim=self.emb_dim,
+            embed_dim=params["enc_dim"],
             num_heads=params["dec_num_heads"],
             proj_value=True,
             dropout=params["dec_att_dropout"],
         )
-        self.norm1 = LayerNorm(self.emb_dim)
+        self.norm1 = LayerNorm(params["enc_dim"])
         self.att = CustomMultiHeadAttention(
-            embed_dim=self.emb_dim,
+            embed_dim=params["enc_dim"],
             num_heads=params["dec_num_heads"],
             proj_value=True,
             dropout=params["dec_att_dropout"],
         )
-        self.linear1 = Linear(self.emb_dim, self.dim_feedforward)
-        self.linear2 = Linear(self.dim_feedforward, self.emb_dim)
+        self.linear1 = Linear(params["enc_dim"], params["dec_dim_feedforward"])
+        self.linear2 = Linear(params["dec_dim_feedforward"], params["enc_dim"])
         self.dropout = Dropout(params["dec_res_dropout"])
-        self.norm2 = LayerNorm(self.emb_dim)
-        self.norm3 = LayerNorm(self.emb_dim)
+        self.norm2 = LayerNorm(params["enc_dim"])
+        self.norm3 = LayerNorm(params["enc_dim"])

     def forward(
         self,
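
The diff does not show GlobalDecoderLayer.forward, but linear1 and linear2 together with the dropout and norm modules above match the standard transformer feed-forward sub-block. A generic, hedged sketch of that pattern (not the project's code; the ReLU activation and the residual ordering are assumptions):

from torch.nn import Dropout, LayerNorm, Linear, Module
from torch.nn.functional import relu


class FeedForwardBlock(Module):  # illustrative only, not part of the project
    def __init__(self, emb_dim, dim_feedforward, dropout=0.1):
        super().__init__()
        self.linear1 = Linear(emb_dim, dim_feedforward)
        self.linear2 = Linear(dim_feedforward, emb_dim)
        self.dropout = Dropout(dropout)
        self.norm = LayerNorm(emb_dim)

    def forward(self, x):
        # Position-wise expansion, activation, projection back, then residual + norm.
        return self.norm(x + self.dropout(self.linear2(relu(self.linear1(x)))))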
@@ -319,11 +311,8 @@ class FeaturesUpdater(Module):
     def __init__(self, params):
         super(FeaturesUpdater, self).__init__()
-        self.enc_dim = params["enc_dim"]
-        self.enc_h_max = params["h_max"]
-        self.enc_w_max = params["w_max"]
         self.pe_2d = PositionalEncoding2D(
-            self.enc_dim, self.enc_h_max, self.enc_w_max, params["device"]
+            params["enc_dim"], params["h_max"], params["w_max"], params["device"]
         )
         self.use_2d_positional_encoding = (
             "use_2d_pe" not in params or params["use_2d_pe"]
@@ -342,9 +331,6 @@ class GlobalHTADecoder(Module):
     def __init__(self, params):
         super(GlobalHTADecoder, self).__init__()
-        self.enc_dim = params["enc_dim"]
-        self.dec_l_max = params["l_max"]
         self.dropout = Dropout(params["dec_pred_dropout"])
         self.dec_att_win = (
             params["attention_win"] if params["attention_win"] is not None else 1
@@ -356,17 +342,17 @@ class GlobalHTADecoder(Module):
         self.att_decoder = GlobalAttDecoder(params)
         self.emb = Embedding(
-            num_embeddings=params["vocab_size"] + 3, embedding_dim=self.enc_dim
+            num_embeddings=params["vocab_size"] + 3, embedding_dim=params["enc_dim"]
         )
         self.pe_1d = PositionalEncoding1D(
-            self.enc_dim, self.dec_l_max, params["device"]
+            params["enc_dim"], params["l_max"], params["device"]
         )
         if self.use_lstm:
-            self.lstm_predict = LSTM(self.enc_dim, self.enc_dim)
+            self.lstm_predict = LSTM(params["enc_dim"], params["enc_dim"])
         vocab_size = params["vocab_size"] + 1
-        self.end_conv = Conv1d(self.enc_dim, vocab_size, kernel_size=1)
+        self.end_conv = Conv1d(params["enc_dim"], vocab_size, kernel_size=1)

     def forward(
         self,
...
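
After this change, the decoder modules read their configuration straight from params instead of copying values onto self. For reference, an illustrative dict covering only the keys visible in this diff; the values are made up, and real configurations likely contain more entries (for instance whatever drives self.use_lstm, which is not shown here):

params = {
    "enc_dim": 256,
    "dec_num_heads": 4,
    "dec_dim_feedforward": 256,
    "dec_att_dropout": 0.1,
    "dec_res_dropout": 0.1,
    "dec_pred_dropout": 0.1,
    "attention_win": 100,
    "vocab_size": 100,
    "l_max": 1000,
    "h_max": 500,
    "w_max": 1000,
    "use_2d_pe": True,
    "device": "cpu",
}
# decoder = GlobalHTADecoder(params)  # assuming the project import, which is not shown in this diff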