From 6679f473475690d158c49765593f81eb94110a6b Mon Sep 17 00:00:00 2001
From: manonBlanco <blanco@teklia.com>
Date: Mon, 17 Jul 2023 09:29:44 +0200
Subject: [PATCH] Never use lstm

---
 dan/decoder.py                       | 18 +-----------------
 dan/ocr/document/train.py            |  1 -
 docs/get_started/training.md         |  1 -
 docs/usage/train/parameters.md       |  1 -
 tests/conftest.py                    |  1 -
 tests/data/prediction/parameters.yml |  1 -
 6 files changed, 1 insertion(+), 22 deletions(-)

diff --git a/dan/decoder.py b/dan/decoder.py
index 84518ab5..af5c28cb 100644
--- a/dan/decoder.py
+++ b/dan/decoder.py
@@ -2,16 +2,7 @@
 
 import torch
 from torch import relu, softmax
-from torch.nn import (
-    LSTM,
-    Conv1d,
-    Dropout,
-    Embedding,
-    LayerNorm,
-    Linear,
-    Module,
-    ModuleList,
-)
+from torch.nn import Conv1d, Dropout, Embedding, LayerNorm, Linear, Module, ModuleList
 from torch.nn.init import xavier_uniform_
 
 
@@ -336,7 +327,6 @@ class GlobalHTADecoder(Module):
             params["attention_win"] if params["attention_win"] is not None else 1
         )
         self.use_1d_pe = "use_1d_pe" not in params or params["use_1d_pe"]
-        self.use_lstm = params["use_lstm"]
 
         self.features_updater = FeaturesUpdater(params)
         self.att_decoder = GlobalAttDecoder(params)
@@ -348,9 +338,6 @@ class GlobalHTADecoder(Module):
             params["enc_dim"], params["l_max"], params["device"]
         )
 
-        if self.use_lstm:
-            self.lstm_predict = LSTM(params["enc_dim"], params["enc_dim"])
-
         vocab_size = params["vocab_size"] + 1
         self.end_conv = Conv1d(params["enc_dim"], vocab_size, kernel_size=1)
 
@@ -426,9 +413,6 @@ class GlobalHTADecoder(Module):
             keep_all_weights=keep_all_weights,
         )
 
-        if self.use_lstm:
-            output, hidden_predict = self.lstm_predict(output, hidden_predict)
-
         dp_output = self.dropout(relu(output))
         preds = self.end_conv(dp_output.permute(1, 2, 0))
 
diff --git a/dan/ocr/document/train.py b/dan/ocr/document/train.py
index 63b5080a..6f5e17b3 100644
--- a/dan/ocr/document/train.py
+++ b/dan/ocr/document/train.py
@@ -153,7 +153,6 @@ def get_config():
             "dec_dim_feedforward": 256,  # number of dimension for feedforward layer in transformer decoder layers
             "use_2d_pe": True,  # use 2D positional embedding
             "use_1d_pe": True,  # use 1D positional embedding
-            "use_lstm": False,
             "attention_win": 100,  # length of attention window
             # Curriculum dropout
             "dropout_scheduler": {
diff --git a/docs/get_started/training.md b/docs/get_started/training.md
index 5438363e..bc0764c3 100644
--- a/docs/get_started/training.md
+++ b/docs/get_started/training.md
@@ -59,7 +59,6 @@ parameters:
     dec_pred_dropout: float
     attention_win: int
     use_1d_pe: bool
-    use_lstm: bool
     vocab_size: int
     h_max: int
     w_max: int
diff --git a/docs/usage/train/parameters.md b/docs/usage/train/parameters.md
index 0e6a2b47..bc346d86 100644
--- a/docs/usage/train/parameters.md
+++ b/docs/usage/train/parameters.md
@@ -138,7 +138,6 @@ For a detailed description of all augmentation transforms, see the [dedicated pa
 | `model_params.dec_dim_feedforward`        | Number of dimensions for feedforward layer in transformer decoder layers.            | `int`         | `256`                                                             |
 | `model_params.use_2d_pe`                  | Whether to use 2D positional embedding.                                              | `bool`        | `True`                                                            |
 | `model_params.use_1d_pe`                  | Whether to use 1D positional embedding.                                              | `bool`        | `True`                                                            |
-| `model_params.use_lstm`                   | Whether to use a LSTM layer in the decoder.                                          | `bool`        | `False`                                                           |
 | `model_params.attention_win`              | Length of attention window.                                                          | `int`         | `100`                                                             |
 | `model_params.dropout_scheduler.function` | Curriculum dropout scheduler.                                                        | custom class  | `exponential_dropout_scheduler`                                   |
 | `model_params.dropout_scheduler.T`        | Exponential factor.                                                                  | `float`       | `5e4`                                                             |
diff --git a/tests/conftest.py b/tests/conftest.py
index e398d9df..136cead1 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -101,7 +101,6 @@ def training_config():
             "dec_dim_feedforward": 256,  # number of dimension for feedforward layer in transformer decoder layers
             "use_2d_pe": True,  # use 2D positional embedding
             "use_1d_pe": True,  # use 1D positional embedding
-            "use_lstm": False,
             "attention_win": 100,  # length of attention window
             # Curriculum dropout
             "dropout_scheduler": {
diff --git a/tests/data/prediction/parameters.yml b/tests/data/prediction/parameters.yml
index bc56c1f6..469afc8d 100644
--- a/tests/data/prediction/parameters.yml
+++ b/tests/data/prediction/parameters.yml
@@ -11,7 +11,6 @@ parameters:
     dec_pred_dropout: 0.1
     attention_win: 100
     use_1d_pe: True
-    use_lstm: False
     vocab_size: 96
     h_max: 500
     w_max: 1000
-- 
GitLab