diff --git a/dan/decoder.py b/dan/decoder.py
index 84518ab5e0c322ed164fe8b4f0fd02d26b74c0d8..af5c28cbba2f3de2076cf5113c8d8a9277294c3b 100644
--- a/dan/decoder.py
+++ b/dan/decoder.py
@@ -2,16 +2,7 @@
 import torch
 from torch import relu, softmax
-from torch.nn import (
-    LSTM,
-    Conv1d,
-    Dropout,
-    Embedding,
-    LayerNorm,
-    Linear,
-    Module,
-    ModuleList,
-)
+from torch.nn import Conv1d, Dropout, Embedding, LayerNorm, Linear, Module, ModuleList
 from torch.nn.init import xavier_uniform_
@@ -336,7 +327,6 @@ class GlobalHTADecoder(Module):
             params["attention_win"] if params["attention_win"] is not None else 1
         )
         self.use_1d_pe = "use_1d_pe" not in params or params["use_1d_pe"]
-        self.use_lstm = params["use_lstm"]
         self.features_updater = FeaturesUpdater(params)
         self.att_decoder = GlobalAttDecoder(params)
@@ -348,9 +338,6 @@
             params["enc_dim"], params["l_max"], params["device"]
         )
-        if self.use_lstm:
-            self.lstm_predict = LSTM(params["enc_dim"], params["enc_dim"])
-
         vocab_size = params["vocab_size"] + 1
         self.end_conv = Conv1d(params["enc_dim"], vocab_size, kernel_size=1)
@@ -426,9 +413,6 @@
             keep_all_weights=keep_all_weights,
         )
-        if self.use_lstm:
-            output, hidden_predict = self.lstm_predict(output, hidden_predict)
-
         dp_output = self.dropout(relu(output))
         preds = self.end_conv(dp_output.permute(1, 2, 0))
diff --git a/dan/ocr/document/train.py b/dan/ocr/document/train.py
index 63b5080a38857a483c3d205296a828b1cfc8d082..6f5e17b36816c7f8dd22008de8c6941e4a7c2ae0 100644
--- a/dan/ocr/document/train.py
+++ b/dan/ocr/document/train.py
@@ -153,7 +153,6 @@ def get_config():
             "dec_dim_feedforward": 256,  # number of dimension for feedforward layer in transformer decoder layers
             "use_2d_pe": True,  # use 2D positional embedding
             "use_1d_pe": True,  # use 1D positional embedding
-            "use_lstm": False,
             "attention_win": 100,  # length of attention window
             # Curriculum dropout
             "dropout_scheduler": {
diff --git a/docs/get_started/training.md b/docs/get_started/training.md
index 71f672d79a3a82efdab223ecf04e20c8f2292c58..c92045bcab40971c144017dec3ba0f705a244180 100644
--- a/docs/get_started/training.md
+++ b/docs/get_started/training.md
@@ -61,7 +61,6 @@ parameters:
     dec_pred_dropout: float
     attention_win: int
     use_1d_pe: bool
-    use_lstm: bool
     vocab_size: int
     h_max: int
     w_max: int
diff --git a/docs/usage/train/parameters.md b/docs/usage/train/parameters.md
index e2aad9531cad0f0775e33fa93ae9d9ac5283a6a2..981ab842ef6ac62dadb60dd7815f86cdd7d2fd89 100644
--- a/docs/usage/train/parameters.md
+++ b/docs/usage/train/parameters.md
@@ -147,7 +147,6 @@ For a detailed description of all augmentation transforms, see the [dedicated pa
 | `model_params.dec_dim_feedforward` | Number of dimensions for feedforward layer in transformer decoder layers. | `int` | `256` |
 | `model_params.use_2d_pe` | Whether to use 2D positional embedding. | `bool` | `True` |
 | `model_params.use_1d_pe` | Whether to use 1D positional embedding. | `bool` | `True` |
-| `model_params.use_lstm` | Whether to use a LSTM layer in the decoder. | `bool` | `False` |
 | `model_params.attention_win` | Length of attention window. | `int` | `100` |
 | `model_params.dropout_scheduler.function` | Curriculum dropout scheduler. | custom class | `exponential_dropout_scheduler` |
 | `model_params.dropout_scheduler.T` | Exponential factor. | `float` | `5e4` |
diff --git a/tests/conftest.py b/tests/conftest.py
index e398d9df0563fea16ae13478ec6f57b2c055abca..136cead1d494741405ee1365ce96116638e6a021 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -101,7 +101,6 @@ def training_config():
             "dec_dim_feedforward": 256,  # number of dimension for feedforward layer in transformer decoder layers
             "use_2d_pe": True,  # use 2D positional embedding
             "use_1d_pe": True,  # use 1D positional embedding
-            "use_lstm": False,
             "attention_win": 100,  # length of attention window
             # Curriculum dropout
             "dropout_scheduler": {
diff --git a/tests/data/prediction/parameters.yml b/tests/data/prediction/parameters.yml
index dbb4bc0f0779e59f14248286ad42586fc1520345..101fe5c83603fb185878472c669e2bc36fa15eca 100644
--- a/tests/data/prediction/parameters.yml
+++ b/tests/data/prediction/parameters.yml
@@ -13,7 +13,6 @@ parameters:
     dec_pred_dropout: 0.1
     attention_win: 100
     use_1d_pe: True
-    use_lstm: False
     vocab_size: 96
     h_max: 500
     w_max: 1000