From 6679f473475690d158c49765593f81eb94110a6b Mon Sep 17 00:00:00 2001 From: manonBlanco <blanco@teklia.com> Date: Mon, 17 Jul 2023 09:29:44 +0200 Subject: [PATCH] Never use lstm --- dan/decoder.py | 18 +----------------- dan/ocr/document/train.py | 1 - docs/get_started/training.md | 1 - docs/usage/train/parameters.md | 1 - tests/conftest.py | 1 - tests/data/prediction/parameters.yml | 1 - 6 files changed, 1 insertion(+), 22 deletions(-) diff --git a/dan/decoder.py b/dan/decoder.py index 84518ab5..af5c28cb 100644 --- a/dan/decoder.py +++ b/dan/decoder.py @@ -2,16 +2,7 @@ import torch from torch import relu, softmax -from torch.nn import ( - LSTM, - Conv1d, - Dropout, - Embedding, - LayerNorm, - Linear, - Module, - ModuleList, -) +from torch.nn import Conv1d, Dropout, Embedding, LayerNorm, Linear, Module, ModuleList from torch.nn.init import xavier_uniform_ @@ -336,7 +327,6 @@ class GlobalHTADecoder(Module): params["attention_win"] if params["attention_win"] is not None else 1 ) self.use_1d_pe = "use_1d_pe" not in params or params["use_1d_pe"] - self.use_lstm = params["use_lstm"] self.features_updater = FeaturesUpdater(params) self.att_decoder = GlobalAttDecoder(params) @@ -348,9 +338,6 @@ class GlobalHTADecoder(Module): params["enc_dim"], params["l_max"], params["device"] ) - if self.use_lstm: - self.lstm_predict = LSTM(params["enc_dim"], params["enc_dim"]) - vocab_size = params["vocab_size"] + 1 self.end_conv = Conv1d(params["enc_dim"], vocab_size, kernel_size=1) @@ -426,9 +413,6 @@ class GlobalHTADecoder(Module): keep_all_weights=keep_all_weights, ) - if self.use_lstm: - output, hidden_predict = self.lstm_predict(output, hidden_predict) - dp_output = self.dropout(relu(output)) preds = self.end_conv(dp_output.permute(1, 2, 0)) diff --git a/dan/ocr/document/train.py b/dan/ocr/document/train.py index 63b5080a..6f5e17b3 100644 --- a/dan/ocr/document/train.py +++ b/dan/ocr/document/train.py @@ -153,7 +153,6 @@ def get_config(): "dec_dim_feedforward": 256, # number of dimension for feedforward layer in transformer decoder layers "use_2d_pe": True, # use 2D positional embedding "use_1d_pe": True, # use 1D positional embedding - "use_lstm": False, "attention_win": 100, # length of attention window # Curriculum dropout "dropout_scheduler": { diff --git a/docs/get_started/training.md b/docs/get_started/training.md index 5438363e..bc0764c3 100644 --- a/docs/get_started/training.md +++ b/docs/get_started/training.md @@ -59,7 +59,6 @@ parameters: dec_pred_dropout: float attention_win: int use_1d_pe: bool - use_lstm: bool vocab_size: int h_max: int w_max: int diff --git a/docs/usage/train/parameters.md b/docs/usage/train/parameters.md index 0e6a2b47..bc346d86 100644 --- a/docs/usage/train/parameters.md +++ b/docs/usage/train/parameters.md @@ -138,7 +138,6 @@ For a detailed description of all augmentation transforms, see the [dedicated pa | `model_params.dec_dim_feedforward` | Number of dimensions for feedforward layer in transformer decoder layers. | `int` | `256` | | `model_params.use_2d_pe` | Whether to use 2D positional embedding. | `bool` | `True` | | `model_params.use_1d_pe` | Whether to use 1D positional embedding. | `bool` | `True` | -| `model_params.use_lstm` | Whether to use a LSTM layer in the decoder. | `bool` | `False` | | `model_params.attention_win` | Length of attention window. | `int` | `100` | | `model_params.dropout_scheduler.function` | Curriculum dropout scheduler. | custom class | `exponential_dropout_scheduler` | | `model_params.dropout_scheduler.T` | Exponential factor. | `float` | `5e4` | diff --git a/tests/conftest.py b/tests/conftest.py index e398d9df..136cead1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -101,7 +101,6 @@ def training_config(): "dec_dim_feedforward": 256, # number of dimension for feedforward layer in transformer decoder layers "use_2d_pe": True, # use 2D positional embedding "use_1d_pe": True, # use 1D positional embedding - "use_lstm": False, "attention_win": 100, # length of attention window # Curriculum dropout "dropout_scheduler": { diff --git a/tests/data/prediction/parameters.yml b/tests/data/prediction/parameters.yml index bc56c1f6..469afc8d 100644 --- a/tests/data/prediction/parameters.yml +++ b/tests/data/prediction/parameters.yml @@ -11,7 +11,6 @@ parameters: dec_pred_dropout: 0.1 attention_win: 100 use_1d_pe: True - use_lstm: False vocab_size: 96 h_max: 500 w_max: 1000 -- GitLab