From f0a6e38ccd03e2594bdc27d59e708655b90e1eb4 Mon Sep 17 00:00:00 2001
From: manonBlanco <blanco@teklia.com>
Date: Mon, 17 Jul 2023 09:16:29 +0200
Subject: [PATCH] Always use 1D and 2D positional embeddings

Remove the use_1d_pe and use_2d_pe model parameters: the decoder now always
applies 1D positional encoding to the target token embeddings and 2D
positional encoding to the encoder features.
---
 dan/decoder.py                       | 12 ++----------
 dan/ocr/document/train.py            |  2 --
 docs/get_started/training.md         |  1 -
 docs/usage/train/parameters.md       |  2 --
 tests/conftest.py                    |  2 --
 tests/data/prediction/parameters.yml |  1 -
 6 files changed, 2 insertions(+), 18 deletions(-)

diff --git a/dan/decoder.py b/dan/decoder.py
index af5c28cb..69e372cf 100644
--- a/dan/decoder.py
+++ b/dan/decoder.py
@@ -305,14 +305,9 @@ class FeaturesUpdater(Module):
         self.pe_2d = PositionalEncoding2D(
             params["enc_dim"], params["h_max"], params["w_max"], params["device"]
         )
-        self.use_2d_positional_encoding = (
-            "use_2d_pe" not in params or params["use_2d_pe"]
-        )
 
     def get_pos_features(self, features):
-        if self.use_2d_positional_encoding:
-            return self.pe_2d(features)
-        return features
+        return self.pe_2d(features)
 
 
 class GlobalHTADecoder(Module):
@@ -326,7 +321,6 @@ class GlobalHTADecoder(Module):
         self.dec_att_win = (
             params["attention_win"] if params["attention_win"] is not None else 1
         )
-        self.use_1d_pe = "use_1d_pe" not in params or params["use_1d_pe"]
 
         self.features_updater = FeaturesUpdater(params)
         self.att_decoder = GlobalAttDecoder(params)
@@ -361,9 +355,7 @@ class GlobalHTADecoder(Module):
         pos_tokens = self.emb(tokens).permute(0, 2, 1)
 
         # Add 1D Positional Encoding
-        if self.use_1d_pe:
-            pos_tokens = self.pe_1d(pos_tokens, start=start)
-        pos_tokens = pos_tokens.permute(2, 0, 1)
+        pos_tokens = self.pe_1d(pos_tokens, start=start).permute(2, 0, 1)
 
         if num_pred is None:
             num_pred = tokens.size(1)
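
For context, the decoder now unconditionally adds a fixed 2D positional encoding to the encoder features (FeaturesUpdater.get_pos_features) and a 1D one to the token embeddings. Below is a minimal sketch of what a 2D sine/cosine positional encoding of this kind typically computes; it is an illustration based on the standard sinusoidal formulation and assumed tensor shapes (batch, channels, height, width), not the actual PositionalEncoding2D implementation from dan/decoder.py.

import math

import torch
from torch.nn import Module


class Sinusoidal2DPositionalEncoding(Module):
    """Illustrative 2D sine/cosine positional encoding (not dan's own code).

    The first half of the channels encodes the vertical position, the second
    half the horizontal position, each with the usual sin/cos interleaving.
    """

    def __init__(self, dim, h_max, w_max):
        super().__init__()
        assert dim % 4 == 0, "dim must be divisible by 4"
        half = dim // 2
        div = torch.exp(torch.arange(0, half, 2).float() * (-math.log(10000.0) / half))
        y = torch.arange(h_max).float().unsqueeze(1)  # (h_max, 1)
        x = torch.arange(w_max).float().unsqueeze(1)  # (w_max, 1)
        pe = torch.zeros(dim, h_max, w_max)
        # vertical position in the first half of the channels
        pe[0:half:2] = torch.sin(y * div).T.unsqueeze(2).repeat(1, 1, w_max)
        pe[1:half:2] = torch.cos(y * div).T.unsqueeze(2).repeat(1, 1, w_max)
        # horizontal position in the second half of the channels
        pe[half::2] = torch.sin(x * div).T.unsqueeze(1).repeat(1, h_max, 1)
        pe[half + 1::2] = torch.cos(x * div).T.unsqueeze(1).repeat(1, h_max, 1)
        self.register_buffer("pe", pe.unsqueeze(0))  # (1, dim, h_max, w_max)

    def forward(self, features):
        # features: (batch, dim, h, w); add the matching slice of the table
        _, _, h, w = features.shape
        return features + self.pe[:, :, :h, :w]

The 1D encoding applied to pos_tokens follows the same scheme over a single sequence index, presumably offset by the start argument during cached decoding.
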
diff --git a/dan/ocr/document/train.py b/dan/ocr/document/train.py
index 6f5e17b3..ad4643b9 100644
--- a/dan/ocr/document/train.py
+++ b/dan/ocr/document/train.py
@@ -151,8 +151,6 @@ def get_config():
             "dec_pred_dropout": 0.1,  # dropout rate before decision layer
             "dec_att_dropout": 0.1,  # dropout rate in multi head attention
             "dec_dim_feedforward": 256,  # number of dimension for feedforward layer in transformer decoder layers
-            "use_2d_pe": True,  # use 2D positional embedding
-            "use_1d_pe": True,  # use 1D positional embedding
             "attention_win": 100,  # length of attention window
             # Curriculum dropout
             "dropout_scheduler": {
diff --git a/docs/get_started/training.md b/docs/get_started/training.md
index bc0764c3..15106d40 100644
--- a/docs/get_started/training.md
+++ b/docs/get_started/training.md
@@ -58,7 +58,6 @@ parameters:
     l_max: int
     dec_pred_dropout: float
     attention_win: int
-    use_1d_pe: bool
     vocab_size: int
     h_max: int
     w_max: int
diff --git a/docs/usage/train/parameters.md b/docs/usage/train/parameters.md
index bc346d86..ac4026fe 100644
--- a/docs/usage/train/parameters.md
+++ b/docs/usage/train/parameters.md
@@ -136,8 +136,6 @@ For a detailed description of all augmentation transforms, see the [dedicated pa
 | `model_params.dec_pred_dropout`           | Dropout rate before decision layer.                                                  | `float`       | `0.1`                                                             |
 | `model_params.dec_att_dropout`            | Dropout rate in multi head attention.                                                | `float`       | `0.1`                                                             |
 | `model_params.dec_dim_feedforward`        | Number of dimensions for feedforward layer in transformer decoder layers.            | `int`         | `256`                                                             |
-| `model_params.use_2d_pe`                  | Whether to use 2D positional embedding.                                              | `bool`        | `True`                                                            |
-| `model_params.use_1d_pe`                  | Whether to use 1D positional embedding.                                              | `bool`        | `True`                                                            |
 | `model_params.attention_win`              | Length of attention window.                                                          | `int`         | `100`                                                             |
 | `model_params.dropout_scheduler.function` | Curriculum dropout scheduler.                                                        | custom class  | `exponential_dropout_scheduler`                                   |
 | `model_params.dropout_scheduler.T`        | Exponential factor.                                                                  | `float`       | `5e4`                                                             |
diff --git a/tests/conftest.py b/tests/conftest.py
index 136cead1..e804cb36 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -99,8 +99,6 @@ def training_config():
             "dec_pred_dropout": 0.1,  # dropout rate before decision layer
             "dec_att_dropout": 0.1,  # dropout rate in multi head attention
             "dec_dim_feedforward": 256,  # number of dimension for feedforward layer in transformer decoder layers
-            "use_2d_pe": True,  # use 2D positional embedding
-            "use_1d_pe": True,  # use 1D positional embedding
             "attention_win": 100,  # length of attention window
             # Curriculum dropout
             "dropout_scheduler": {
diff --git a/tests/data/prediction/parameters.yml b/tests/data/prediction/parameters.yml
index 469afc8d..2bad8803 100644
--- a/tests/data/prediction/parameters.yml
+++ b/tests/data/prediction/parameters.yml
@@ -10,7 +10,6 @@ parameters:
     l_max: 15000
     dec_pred_dropout: 0.1
     attention_win: 100
-    use_1d_pe: True
     vocab_size: 96
     h_max: 500
     w_max: 1000
-- 
GitLab