Verified Commit 2e67db27 authored by Mélodie Boillet

Apply 6679f473

parent c0ac48fa
@@ -2,16 +2,7 @@
 import torch
 from torch import relu, softmax
-from torch.nn import (
-    LSTM,
-    Conv1d,
-    Dropout,
-    Embedding,
-    LayerNorm,
-    Linear,
-    Module,
-    ModuleList,
-)
+from torch.nn import Conv1d, Dropout, Embedding, LayerNorm, Linear, Module, ModuleList
 from torch.nn.init import xavier_uniform_
@@ -336,7 +327,6 @@ class GlobalHTADecoder(Module):
             params["attention_win"] if params["attention_win"] is not None else 1
         )
         self.use_1d_pe = "use_1d_pe" not in params or params["use_1d_pe"]
-        self.use_lstm = params["use_lstm"]
         self.features_updater = FeaturesUpdater(params)
         self.att_decoder = GlobalAttDecoder(params)
@@ -348,9 +338,6 @@
                 params["enc_dim"], params["l_max"], params["device"]
             )
-        if self.use_lstm:
-            self.lstm_predict = LSTM(params["enc_dim"], params["enc_dim"])
-
         vocab_size = params["vocab_size"] + 1
         self.end_conv = Conv1d(params["enc_dim"], vocab_size, kernel_size=1)
@@ -426,9 +413,6 @@
             keep_all_weights=keep_all_weights,
         )
-        if self.use_lstm:
-            output, hidden_predict = self.lstm_predict(output, hidden_predict)
-
         dp_output = self.dropout(relu(output))
         preds = self.end_conv(dp_output.permute(1, 2, 0))
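After this change, the decoder's prediction head reduces to a ReLU plus dropout on the transformer output followed by a 1x1 convolution over the vocabulary, with no optional LSTM pass in between. The following is a minimal standalone sketch of that head; the tensor sizes below are illustrative assumptions, not values taken from the repository:

```python
import torch
from torch import relu
from torch.nn import Conv1d, Dropout

# Illustrative sizes (assumptions for this sketch only).
enc_dim, vocab_size, seq_len, batch_size = 256, 97, 50, 2

dropout = Dropout(p=0.1)
end_conv = Conv1d(enc_dim, vocab_size, kernel_size=1)

# Transformer decoder output, shaped (sequence, batch, features).
output = torch.randn(seq_len, batch_size, enc_dim)

# Prediction head without the removed LSTM branch:
# ReLU + dropout, then permute to (batch, features, sequence) for the 1x1 convolution.
dp_output = dropout(relu(output))
preds = end_conv(dp_output.permute(1, 2, 0))  # (batch, vocab_size, sequence)
```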
@@ -153,7 +153,6 @@ def get_config():
             "dec_dim_feedforward": 256,  # number of dimension for feedforward layer in transformer decoder layers
             "use_2d_pe": True,  # use 2D positional embedding
             "use_1d_pe": True,  # use 1D positional embedding
-            "use_lstm": False,
             "attention_win": 100,  # length of attention window
             # Curriculum dropout
             "dropout_scheduler": {
@@ -61,7 +61,6 @@ parameters:
   dec_pred_dropout: float
   attention_win: int
   use_1d_pe: bool
-  use_lstm: bool
   vocab_size: int
   h_max: int
   w_max: int
@@ -147,7 +147,6 @@ For a detailed description of all augmentation transforms, see the [dedicated pa
 | `model_params.dec_dim_feedforward` | Number of dimensions for feedforward layer in transformer decoder layers. | `int` | `256` |
 | `model_params.use_2d_pe` | Whether to use 2D positional embedding. | `bool` | `True` |
 | `model_params.use_1d_pe` | Whether to use 1D positional embedding. | `bool` | `True` |
-| `model_params.use_lstm` | Whether to use a LSTM layer in the decoder. | `bool` | `False` |
 | `model_params.attention_win` | Length of attention window. | `int` | `100` |
 | `model_params.dropout_scheduler.function` | Curriculum dropout scheduler. | custom class | `exponential_dropout_scheduler` |
 | `model_params.dropout_scheduler.T` | Exponential factor. | `float` | `5e4` |
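For reference, a `model_params` fragment consistent with the table after this change could look as follows. This is an illustrative sketch restricted to the keys documented above, not a complete configuration; in the repository's config the `function` entry holds the scheduler function itself rather than a string, which is only named here to keep the snippet self-contained.

```python
# Illustrative slice of model_params once use_lstm is removed.
# A real configuration contains many more entries.
model_params = {
    "dec_dim_feedforward": 256,  # feedforward dimension in transformer decoder layers
    "use_2d_pe": True,           # 2D positional embedding
    "use_1d_pe": True,           # 1D positional embedding
    "attention_win": 100,        # length of attention window
    "dropout_scheduler": {
        "function": "exponential_dropout_scheduler",  # named here for illustration only
        "T": 5e4,                # exponential factor
    },
}
```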
@@ -101,7 +101,6 @@ def training_config():
             "dec_dim_feedforward": 256,  # number of dimension for feedforward layer in transformer decoder layers
             "use_2d_pe": True,  # use 2D positional embedding
             "use_1d_pe": True,  # use 1D positional embedding
-            "use_lstm": False,
             "attention_win": 100,  # length of attention window
             # Curriculum dropout
             "dropout_scheduler": {
@@ -13,7 +13,6 @@ parameters:
   dec_pred_dropout: 0.1
   attention_win: 100
   use_1d_pe: True
-  use_lstm: False
   vocab_size: 96
   h_max: 500
   w_max: 1000