From d5a223de1b94068bf05f32fb3063ddf6007c9dd1 Mon Sep 17 00:00:00 2001
From: Yoann Schneider <yschneider@teklia.com>
Date: Wed, 9 Aug 2023 09:26:45 +0000
Subject: [PATCH] Always use the same dropout scheduler function

---
 dan/manager/training.py        | 16 ++++---------
 dan/ocr/document/train.py      |  9 -------
 dan/schedulers.py              |  5 ++--
 docs/usage/train/parameters.md | 44 ++++++++++++++++------------------
 tests/conftest.py              |  6 -----
 5 files changed, 28 insertions(+), 52 deletions(-)

diff --git a/dan/manager/training.py b/dan/manager/training.py
index ee944e17..ab914ead 100644
--- a/dan/manager/training.py
+++ b/dan/manager/training.py
@@ -32,7 +32,6 @@ class GenericTrainingManager:
         self.type = None
         self.is_master = False
         self.params = params
-        self.dropout_scheduler = None
         self.models = {}
         self.dataset = None
         self.dataset_name = list(self.params["dataset_params"]["datasets"].values())[0]
@@ -196,10 +195,7 @@ class GenericTrainingManager:
                 )
 
         # Handle curriculum dropout
-        if "dropout_scheduler" in self.params["model_params"]:
-            func = self.params["model_params"]["dropout_scheduler"]["function"]
-            T = self.params["model_params"]["dropout_scheduler"]["T"]
-            self.dropout_scheduler = DropoutScheduler(self.models, func, T)
+        self.dropout_scheduler = DropoutScheduler(self.models)
 
         self.scaler = GradScaler(enabled=self.params["training_params"]["use_amp"])
 
@@ -652,13 +648,11 @@ class GenericTrainingManager:
                                 if "lr" in metric_names:
                                     self.writer.add_scalar(
                                         "lr_{}".format(model_name),
-                                        self.lr_schedulers[model_name].lr,
-                                        self.lr_schedulers[model_name].step_num,
                                     )
-                    # Update dropout scheduler if used
-                    if self.dropout_scheduler:
-                        self.dropout_scheduler.step(len(batch_metrics["names"]))
-                        self.dropout_scheduler.update_dropout_rate()
+
+                    # Update dropout scheduler
+                    self.dropout_scheduler.step(len(batch_metrics["names"]))
+                    self.dropout_scheduler.update_dropout_rate()
 
                     # Add batch metrics values to epoch metrics values
                     self.metric_manager["train"].update_metrics(batch_metrics)
diff --git a/dan/ocr/document/train.py b/dan/ocr/document/train.py
index 712c8a34..e3aafdce 100644
--- a/dan/ocr/document/train.py
+++ b/dan/ocr/document/train.py
@@ -14,7 +14,6 @@ from dan.decoder import GlobalHTADecoder
 from dan.encoder import FCN_Encoder
 from dan.manager.training import Manager
 from dan.mlflow import MLFLOW_AVAILABLE
-from dan.schedulers import exponential_dropout_scheduler
 from dan.transforms import Preprocessing
 from dan.utils import MLflowNotInstalled
 
@@ -152,11 +151,6 @@ def get_config():
             "dec_att_dropout": 0.1,  # dropout rate in multi head attention
             "dec_dim_feedforward": 256,  # number of dimension for feedforward layer in transformer decoder layers
             "attention_win": 100,  # length of attention window
-            # Curriculum dropout
-            "dropout_scheduler": {
-                "function": exponential_dropout_scheduler,
-                "T": 5e4,
-            },
         },
         "training_params": {
             "output_folder": "outputs/dan_esposalles_record",  # folder name for checkpoint and results
@@ -238,9 +232,6 @@ def serialize_config(config):
     serialized_config["dataset_params"]["config"]["augmentation"] = str(
         serialized_config["dataset_params"]["config"]["augmentation"]
     )
-    serialized_config["model_params"]["dropout_scheduler"]["function"] = str(
-        serialized_config["model_params"]["dropout_scheduler"]["function"]
-    )
     serialized_config["training_params"]["nb_gpu"] = str(
         serialized_config["training_params"]["nb_gpu"]
     )
diff --git a/dan/schedulers.py b/dan/schedulers.py
index 752c300a..3c6ef238 100644
--- a/dan/schedulers.py
+++ b/dan/schedulers.py
@@ -4,14 +4,13 @@ from torch.nn import Dropout, Dropout2d
 
 
 class DropoutScheduler:
-    def __init__(self, models, function, T=1e5):
+    def __init__(self, models, T=5e4):
         """
         T: number of gradient updates to converge
         """
 
         self.teta_list = list()
         self.init_teta_list(models)
-        self.function = function
         self.T = T
         self.step_num = 0
 
@@ -31,7 +30,7 @@ class DropoutScheduler:
 
     def update_dropout_rate(self):
         for module, p in self.teta_list:
-            module.p = self.function(p, self.step_num, self.T)
+            module.p = exponential_dropout_scheduler(p, self.step_num, self.T)
 
 
 def exponential_dropout_scheduler(dropout_rate, step, max_step):
diff --git a/docs/usage/train/parameters.md b/docs/usage/train/parameters.md
index b031dab9..867bce77 100644
--- a/docs/usage/train/parameters.md
+++ b/docs/usage/train/parameters.md
@@ -124,29 +124,27 @@ For a detailed description of all augmentation transforms, see the [dedicated pa
 
 ## Model parameters
 
-| Name                                      | Description                                                                            | Type         | Default                                                           |
-| ----------------------------------------- | -------------------------------------------------------------------------------------- | ------------ | ----------------------------------------------------------------- |
-| `model_params.models.encoder`             | Encoder class.                                                                         | custom class | `FCN_encoder`                                                     |
-| `model_params.models.decoder`             | Decoder class.                                                                         | custom class | `GlobalHTADecoder`                                                |
-| `model_params.transfer_learning.encoder`  | Model to load for the encoder \[state_dict_name, checkpoint_path, learnable, strict\]. | `list`       | `["encoder", "pretrained_models/dan_rimes_page.pt", True, True]`  |
-| `model_params.transfer_learning.decoder`  | Model to load for the decoder \[state_dict_name, checkpoint_path, learnable, strict\]. | `list`       | `["encoder", "pretrained_models/dan_rimes_page.pt", True, False]` |
-| `model_params.transfered_charset`         | Transfer learning of the decision layer based on charset of the model to transfer.     | `bool`       | `True`                                                            |
-| `model_params.additional_tokens`          | For decision layer = \[<eot>, \], only for transferred charset.                        | `int`        | `1`                                                               |
-| `model_params.dropout`                    | Dropout probability in the encoder.                                                    | `float`      | `0.5`                                                             |
-| `model_params.enc_dim`                    | Dimension of features extracted by the encoder.                                        | `int`        | `256`                                                             |
-| `model_params.nb_layers`                  | Number of layers in the encoder.                                                       | `int`        | `5`                                                               |
-| `model_params.h_max`                      | Maximum height for encoder output (for 2D positional embedding).                       | `int`        | `500`                                                             |
-| `model_params.w_max`                      | Maximum width for encoder output (for 2D positional embedding).                        | `int`        | `1000`                                                            |
-| `model_params.l_max`                      | Maximum predicted sequence length (for 1D positional embedding).                       | `int`        | `15000`                                                           |
-| `model_params.dec_num_layers`             | Number of transformer decoder layers.                                                  | `int`        | `8`                                                               |
-| `model_params.dec_num_heads`              | Number of heads in transformer decoder layers.                                         | `int`        | `4`                                                               |
-| `model_params.dec_res_dropout`            | Dropout probability in transformer decoder layers.                                     | `int`        | `0.1`                                                             |
-| `model_params.dec_pred_dropout`           | Dropout rate before decision layer.                                                    | `float`      | `0.1`                                                             |
-| `model_params.dec_att_dropout`            | Dropout rate in multi head attention.                                                  | `float`      | `0.1`                                                             |
-| `model_params.dec_dim_feedforward`        | Number of dimensions for feedforward layer in transformer decoder layers.              | `int`        | `256`                                                             |
-| `model_params.attention_win`              | Length of attention window.                                                            | `int`        | `100`                                                             |
-| `model_params.dropout_scheduler.function` | Curriculum dropout scheduler.                                                          | custom class | `exponential_dropout_scheduler`                                   |
-| `model_params.dropout_scheduler.T`        | Exponential factor.                                                                    | `float`      | `5e4`                                                             |
+| Name                                     | Description                                                                            | Type         | Default                                                           |
+| ---------------------------------------- | -------------------------------------------------------------------------------------- | ------------ | ----------------------------------------------------------------- |
+| `model_params.models.encoder`            | Encoder class.                                                                         | custom class | `FCN_encoder`                                                     |
+| `model_params.models.decoder`            | Decoder class.                                                                         | custom class | `GlobalHTADecoder`                                                |
+| `model_params.transfer_learning.encoder` | Model to load for the encoder \[state_dict_name, checkpoint_path, learnable, strict\]. | `list`       | `["encoder", "pretrained_models/dan_rimes_page.pt", True, True]`  |
+| `model_params.transfer_learning.decoder` | Model to load for the decoder \[state_dict_name, checkpoint_path, learnable, strict\]. | `list`       | `["encoder", "pretrained_models/dan_rimes_page.pt", True, False]` |
+| `model_params.transfered_charset`        | Transfer learning of the decision layer based on charset of the model to transfer.     | `bool`       | `True`                                                            |
+| `model_params.additional_tokens`         | For decision layer = \[<eot>, \], only for transferred charset.                        | `int`        | `1`                                                               |
+| `model_params.dropout`                   | Dropout probability in the encoder.                                                    | `float`      | `0.5`                                                             |
+| `model_params.enc_dim`                   | Dimension of features extracted by the encoder.                                        | `int`        | `256`                                                             |
+| `model_params.nb_layers`                 | Number of layers in the encoder.                                                       | `int`        | `5`                                                               |
+| `model_params.h_max`                     | Maximum height for encoder output (for 2D positional embedding).                       | `int`        | `500`                                                             |
+| `model_params.w_max`                     | Maximum width for encoder output (for 2D positional embedding).                        | `int`        | `1000`                                                            |
+| `model_params.l_max`                     | Maximum predicted sequence length (for 1D positional embedding).                       | `int`        | `15000`                                                           |
+| `model_params.dec_num_layers`            | Number of transformer decoder layers.                                                  | `int`        | `8`                                                               |
+| `model_params.dec_num_heads`             | Number of heads in transformer decoder layers.                                         | `int`        | `4`                                                               |
+| `model_params.dec_res_dropout`           | Dropout probability in transformer decoder layers.                                     | `int`        | `0.1`                                                             |
+| `model_params.dec_pred_dropout`          | Dropout rate before decision layer.                                                    | `float`      | `0.1`                                                             |
+| `model_params.dec_att_dropout`           | Dropout rate in multi head attention.                                                  | `float`      | `0.1`                                                             |
+| `model_params.dec_dim_feedforward`       | Number of dimensions for feedforward layer in transformer decoder layers.              | `int`        | `256`                                                             |
+| `model_params.attention_win`             | Length of attention window.                                                            | `int`        | `100`                                                             |
 
 ## Training parameters
 
diff --git a/tests/conftest.py b/tests/conftest.py
index c0bcfb8c..7ce04118 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -7,7 +7,6 @@ from torch.optim import Adam
 from arkindex_export import open_database
 from dan.decoder import GlobalHTADecoder
 from dan.encoder import FCN_Encoder
-from dan.schedulers import exponential_dropout_scheduler
 from dan.transforms import Preprocessing
 
 FIXTURES = Path(__file__).resolve().parent / "data"
@@ -83,11 +82,6 @@ def training_config():
             "dec_att_dropout": 0.1,  # dropout rate in multi head attention
             "dec_dim_feedforward": 256,  # number of dimension for feedforward layer in transformer decoder layers
             "attention_win": 100,  # length of attention window
-            # Curriculum dropout
-            "dropout_scheduler": {
-                "function": exponential_dropout_scheduler,
-                "T": 5e4,
-            },
         },
         "training_params": {
             "output_folder": "dan_trained_model",  # folder name for checkpoint and results
-- 
GitLab
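
A quick usage sketch for reviewers of how the scheduler is wired after this change. The constructor call and the two per-iteration calls mirror the hunks above; the toy model, the batch size, and the described ramp-up behaviour are illustrative assumptions (the exact exponential formula lives in dan/schedulers.py and is not shown in this diff).

    from torch.nn import Dropout, Linear, Sequential

    from dan.schedulers import DropoutScheduler

    # Stand-in for self.models in GenericTrainingManager (hypothetical toy model).
    models = {"decoder": Sequential(Linear(8, 8), Dropout(p=0.5))}

    # The scheduler is now always built, with no `function` argument:
    # the exponential schedule is hardcoded and T defaults to 5e4.
    dropout_scheduler = DropoutScheduler(models)

    # One training iteration with a batch of 4 samples: advance the schedule
    # by the batch size, then refresh every Dropout/Dropout2d rate it tracks.
    dropout_scheduler.step(4)
    dropout_scheduler.update_dropout_rate()

    # Early in training the applied rate is close to 0 and ramps up towards
    # the configured 0.5 as step_num approaches T (curriculum dropout).
    print(models["decoder"][1].p)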