Commit 6679f473 authored by Manon Blanco, committed by Mélodie Boillet

Never use lstm

parent fec09eba
1 merge request: !208 Never use lstm
@@ -2,16 +2,7 @@
 import torch
 from torch import relu, softmax
-from torch.nn import (
-    LSTM,
-    Conv1d,
-    Dropout,
-    Embedding,
-    LayerNorm,
-    Linear,
-    Module,
-    ModuleList,
-)
+from torch.nn import Conv1d, Dropout, Embedding, LayerNorm, Linear, Module, ModuleList
 from torch.nn.init import xavier_uniform_
@@ -336,7 +327,6 @@ class GlobalHTADecoder(Module):
             params["attention_win"] if params["attention_win"] is not None else 1
         )
         self.use_1d_pe = "use_1d_pe" not in params or params["use_1d_pe"]
-        self.use_lstm = params["use_lstm"]
         self.features_updater = FeaturesUpdater(params)
         self.att_decoder = GlobalAttDecoder(params)
@@ -348,9 +338,6 @@ class GlobalHTADecoder(Module):
             params["enc_dim"], params["l_max"], params["device"]
         )
-        if self.use_lstm:
-            self.lstm_predict = LSTM(params["enc_dim"], params["enc_dim"])
-
         vocab_size = params["vocab_size"] + 1
         self.end_conv = Conv1d(params["enc_dim"], vocab_size, kernel_size=1)
@@ -426,9 +413,6 @@ class GlobalHTADecoder(Module):
             keep_all_weights=keep_all_weights,
         )
-        if self.use_lstm:
-            output, hidden_predict = self.lstm_predict(output, hidden_predict)
-
         dp_output = self.dropout(relu(output))
         preds = self.end_conv(dp_output.permute(1, 2, 0))
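For context: after this change, the decoder's character-prediction head is purely convolutional. The attention decoder output goes through ReLU and dropout, then a kernel-size-1 Conv1d maps enc_dim features to vocab_size + 1 logits, with no recurrent step in between. Below is a minimal, runnable sketch of that path, using illustrative dimensions rather than the repository's actual configuration:

import torch
from torch import relu
from torch.nn import Conv1d, Dropout

# Illustrative dimensions only, not taken from the repository's configs
enc_dim, vocab_size, seq_len, batch = 256, 96, 10, 2

dropout = Dropout(p=0.5)
# vocab_size + 1 output channels, as in the diff above
end_conv = Conv1d(enc_dim, vocab_size + 1, kernel_size=1)

# Attention decoder output, shaped (seq_len, batch, enc_dim)
output = torch.randn(seq_len, batch, enc_dim)

# Same prediction path as the patched forward(): no LSTM step
dp_output = dropout(relu(output))
preds = end_conv(dp_output.permute(1, 2, 0))
print(preds.shape)  # torch.Size([2, 97, 10]) -> (batch, vocab_size + 1, seq_len)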
@@ -153,7 +153,6 @@ def get_config():
             "dec_dim_feedforward": 256,  # number of dimension for feedforward layer in transformer decoder layers
             "use_2d_pe": True,  # use 2D positional embedding
             "use_1d_pe": True,  # use 1D positional embedding
-            "use_lstm": False,
             "attention_win": 100,  # length of attention window
             # Curriculum dropout
             "dropout_scheduler": {
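Since the decoder no longer reads use_lstm, configurations written for older revisions that still contain the key can simply have it dropped. The helper below is purely illustrative (strip_removed_keys is not part of the repository) and assumes model_params is a plain dict, as in get_config() above:

def strip_removed_keys(model_params: dict) -> dict:
    """Return a copy of model_params without keys removed by this commit."""
    removed = {"use_lstm"}
    return {key: value for key, value in model_params.items() if key not in removed}

# Values copied from the diff, plus the now-obsolete key
old_model_params = {
    "dec_dim_feedforward": 256,
    "use_2d_pe": True,
    "use_1d_pe": True,
    "use_lstm": False,
    "attention_win": 100,
}
print(strip_removed_keys(old_model_params))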
@@ -59,7 +59,6 @@ parameters:
     dec_pred_dropout: float
     attention_win: int
     use_1d_pe: bool
-    use_lstm: bool
     vocab_size: int
     h_max: int
     w_max: int
@@ -138,7 +138,6 @@ For a detailed description of all augmentation transforms, see the [dedicated pa
 | `model_params.dec_dim_feedforward` | Number of dimensions for feedforward layer in transformer decoder layers. | `int` | `256` |
 | `model_params.use_2d_pe` | Whether to use 2D positional embedding. | `bool` | `True` |
 | `model_params.use_1d_pe` | Whether to use 1D positional embedding. | `bool` | `True` |
-| `model_params.use_lstm` | Whether to use a LSTM layer in the decoder. | `bool` | `False` |
 | `model_params.attention_win` | Length of attention window. | `int` | `100` |
 | `model_params.dropout_scheduler.function` | Curriculum dropout scheduler. | custom class | `exponential_dropout_scheduler` |
 | `model_params.dropout_scheduler.T` | Exponential factor. | `float` | `5e4` |
@@ -101,7 +101,6 @@ def training_config():
             "dec_dim_feedforward": 256,  # number of dimension for feedforward layer in transformer decoder layers
             "use_2d_pe": True,  # use 2D positional embedding
             "use_1d_pe": True,  # use 1D positional embedding
-            "use_lstm": False,
             "attention_win": 100,  # length of attention window
             # Curriculum dropout
             "dropout_scheduler": {
@@ -11,7 +11,6 @@ parameters:
     dec_pred_dropout: 0.1
     attention_win: 100
     use_1d_pe: True
-    use_lstm: False
     vocab_size: 96
     h_max: 500
     w_max: 1000