From 8675a2109b76a172e3d2e2964ec5b44762e60ae7 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?M=C3=A9lodie=20Boillet?= <boillet@teklia.com>
Date: Mon, 7 Aug 2023 08:22:57 +0000
Subject: [PATCH] Support DistributedDataParallel

---
 dan/manager/ocr.py             |  2 --
 dan/manager/training.py        | 62 ++++++++++++++++++++++++++--------
 dan/utils.py                   | 15 ++++++++
 docs/usage/train/jeanzay.md    |  8 +++++
 docs/usage/train/parameters.md |  5 +--
 5 files changed, 73 insertions(+), 19 deletions(-)

diff --git a/dan/manager/ocr.py b/dan/manager/ocr.py
index 077f965c..8e18b813 100644
--- a/dan/manager/ocr.py
+++ b/dan/manager/ocr.py
@@ -127,7 +127,6 @@ class OCRDatasetManager:
             batch_size=self.params["batch_size"],
             shuffle=True if self.train_sampler is None else False,
             drop_last=False,
-            batch_sampler=self.train_sampler,
             sampler=self.train_sampler,
             num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
             pin_memory=self.pin_memory,
@@ -141,7 +140,6 @@ class OCRDatasetManager:
                 self.valid_datasets[key],
                 batch_size=1,
                 sampler=self.valid_samplers[key],
-                batch_sampler=self.valid_samplers[key],
                 shuffle=False,
                 num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
                 pin_memory=self.pin_memory,
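These two hunks drop the `batch_sampler` argument: PyTorch's `DataLoader` treats `sampler` and `batch_sampler` as mutually exclusive, so passing the distributed sampler through both raises a `ValueError` as soon as `use_ddp` is enabled. A minimal sketch of the corrected construction, using an illustrative dataset and sampler rather than the actual `OCRDatasetManager` objects:

```python
# Hedged sketch of the fixed DataLoader call; the dataset and sampler are
# illustrative stand-ins, not the ones built by OCRDatasetManager.
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(100, dtype=torch.float32))

# With DDP, this would be DistributedSampler(dataset), created once the
# process group is initialised; without DDP it stays None.
train_sampler = None

train_loader = DataLoader(
    dataset,
    batch_size=4,
    # The distributed sampler already shuffles per epoch, so shuffle is
    # only used in the single-process case.
    shuffle=train_sampler is None,
    drop_last=False,
    sampler=train_sampler,  # pass the sampler here only, never as batch_sampler too
)
```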
diff --git a/dan/manager/training.py b/dan/manager/training.py
index af6156cc..fdaf5d3d 100644
--- a/dan/manager/training.py
+++ b/dan/manager/training.py
@@ -21,7 +21,7 @@ from dan.manager.metrics import MetricManager
 from dan.manager.ocr import OCRDatasetManager
 from dan.mlflow import MLFLOW_AVAILABLE, logging_metrics, logging_tags_metrics
 from dan.schedulers import DropoutScheduler
-from dan.utils import ind_to_token
+from dan.utils import fix_ddp_layers_names, ind_to_token
 
 if MLFLOW_AVAILABLE:
     import mlflow
@@ -55,6 +55,11 @@ class GenericTrainingManager:
         self.params["model_params"]["use_amp"] = self.params["training_params"][
             "use_amp"
         ]
+        self.nb_gpu = (
+            self.params["training_params"]["nb_gpu"]
+            if self.params["training_params"]["use_ddp"]
+            else 1
+        )
 
     def init_paths(self):
         """
@@ -184,7 +189,9 @@ class GenericTrainingManager:
             # make the model compatible with Distributed Data Parallel if used
             if self.params["training_params"]["use_ddp"]:
                 self.models[model_name] = DDP(
-                    self.models[model_name], [self.ddp_config["rank"]]
+                    self.models[model_name],
+                    [self.ddp_config["rank"]],
+                    output_device=self.ddp_config["rank"],
                 )
 
         # Handle curriculum dropout
@@ -214,7 +221,10 @@ class GenericTrainingManager:
         if self.params["training_params"]["load_epoch"] in ("best", "last"):
             for filename in os.listdir(self.paths["checkpoints"]):
                 if self.params["training_params"]["load_epoch"] in filename:
-                    return torch.load(os.path.join(self.paths["checkpoints"], filename))
+                    return torch.load(
+                        os.path.join(self.paths["checkpoints"], filename),
+                        map_location=self.device,
+                    )
         return None
 
     def load_existing_model(self, checkpoint, strict=True):
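The `map_location=self.device` argument matters for DDP: by default `torch.load` restores tensors to the device they were saved from (typically `cuda:0`), so every worker would pile its checkpoint onto the first GPU. A hedged sketch of that pattern, with an illustrative rank and checkpoint path:

```python
# Hedged sketch of device-aware checkpoint loading; the rank and path are
# illustrative, not taken from the DAN code base.
import torch

rank = 0  # local rank of the current DDP process
device = torch.device(f"cuda:{rank}" if torch.cuda.is_available() else "cpu")

# Deserialise the checkpoint directly onto this process's device instead of
# the device it was saved from:
# checkpoint = torch.load("outputs/checkpoints/last_100.pt", map_location=device)
```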
@@ -230,8 +240,14 @@ class GenericTrainingManager:
             self.scaler.load_state_dict(checkpoint["scaler_state_dict"])
         # Load model weights from past training
         for model_name in self.models.keys():
+            # Rename layers to/from the DDP naming scheme
+            checkpoint[f"{model_name}_state_dict"] = fix_ddp_layers_names(
+                checkpoint[f"{model_name}_state_dict"],
+                self.params["training_params"]["use_ddp"],
+            )
+
             self.models[model_name].load_state_dict(
-                checkpoint["{}_state_dict".format(model_name)], strict=strict
+                checkpoint[f"{model_name}_state_dict"], strict=strict
             )
 
     def init_new_model(self):
@@ -252,8 +268,15 @@ class GenericTrainingManager:
                 state_dict_name, path, learnable, strict = self.params["model_params"][
                     "transfer_learning"
                 ][model_name]
+
                 # Loading pretrained weights file
-                checkpoint = torch.load(path)
+                checkpoint = torch.load(path, map_location=self.device)
+                # Rename layers to/from the DDP naming scheme
+                checkpoint[f"{model_name}_state_dict"] = fix_ddp_layers_names(
+                    checkpoint[f"{model_name}_state_dict"],
+                    self.params["training_params"]["use_ddp"],
+                )
+
                 try:
                     # Load pretrained weights for model
                     self.models[model_name].load_state_dict(
@@ -595,7 +618,6 @@ class GenericTrainingManager:
             self.metric_manager["train"] = MetricManager(
                 metric_names=metric_names, dataset_name=self.dataset_name
             )
-
             with tqdm(total=len(self.dataset.train_loader.dataset)) as pbar:
                 pbar.set_description("EPOCH {}/{}".format(num_epoch, nb_epochs))
                 # iterates over mini-batch data
@@ -644,7 +666,7 @@ class GenericTrainingManager:
                     self.metric_manager["train"].update_metrics(batch_metrics)
                     display_values = self.metric_manager["train"].get_display_values()
                     pbar.set_postfix(values=str(display_values))
-                    pbar.update(len(batch_data["names"]))
+                    pbar.update(len(batch_data["names"]) * self.nb_gpu)
 
                 # Log MLflow metrics
                 logging_metrics(
@@ -737,7 +759,7 @@ class GenericTrainingManager:
                     display_values = self.metric_manager[set_name].get_display_values()
 
                     pbar.set_postfix(values=str(display_values))
-                    pbar.update(len(batch_data["names"]))
+                    pbar.update(len(batch_data["names"]) * self.nb_gpu)
 
                 # log metrics in MLflow
                 logging_metrics(
@@ -789,7 +811,7 @@ class GenericTrainingManager:
                     ].get_display_values()
 
                     pbar.set_postfix(values=str(display_values))
-                    pbar.update(len(batch_data["names"]))
+                    pbar.update(len(batch_data["names"]) * self.nb_gpu)
 
                 # log metrics in MLflow
                 logging_name = custom_name.split("-")[1]
@@ -991,9 +1013,14 @@ class Manager(OCRManager):
             features_size = raw_features.size()
             b, c, h, w = features_size
 
-            pos_features = self.models["decoder"].features_updater.get_pos_features(
-                raw_features
-            )
+            if self.params["training_params"]["use_ddp"]:
+                pos_features = self.models[
+                    "decoder"
+                ].module.features_updater.get_pos_features(raw_features)
+            else:
+                pos_features = self.models["decoder"].features_updater.get_pos_features(
+                    raw_features
+                )
             features = torch.flatten(pos_features, start_dim=2, end_dim=3).permute(
                 2, 0, 1
             )
@@ -1086,9 +1113,14 @@ class Manager(OCRManager):
             else:
                 features = self.models["encoder"](x)
             features_size = features.size()
-            pos_features = self.models["decoder"].features_updater.get_pos_features(
-                features
-            )
+            if self.params["training_params"]["use_ddp"]:
+                pos_features = self.models[
+                    "decoder"
+                ].module.features_updater.get_pos_features(features)
+            else:
+                pos_features = self.models["decoder"].features_updater.get_pos_features(
+                    features
+                )
             features = torch.flatten(pos_features, start_dim=2, end_dim=3).permute(
                 2, 0, 1
             )
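The two `use_ddp` branches above are needed because `DistributedDataParallel` does not expose the wrapped model's custom attributes directly: once the decoder is wrapped, `features_updater` is only reachable through the wrapper's `.module`. A toy illustration, where `Decoder` is a hypothetical stand-in for the DAN decoder:

```python
# Toy illustration of the .module indirection; Decoder is a hypothetical
# stand-in, not the actual DAN decoder.
from torch import nn

class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 8)
        self.features_updater = nn.Identity()  # custom attribute, as in DAN

decoder = Decoder()

# Without DDP, custom attributes are reached directly.
features_updater = decoder.features_updater

# With DDP, the original module lives under `.module`, hence the extra hop
# (needs an initialised process group, so only sketched here):
# ddp_decoder = nn.parallel.DistributedDataParallel(decoder, [rank], output_device=rank)
# features_updater = ddp_decoder.module.features_updater
```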
diff --git a/dan/utils.py b/dan/utils.py
index 0e39aba6..84e73824 100644
--- a/dan/utils.py
+++ b/dan/utils.py
@@ -72,6 +72,21 @@ def ind_to_token(labels, ind, oov_symbol=None):
     return "".join(res)
 
 
+def fix_ddp_layers_names(model, to_ddp):
+    """
+    Rename the layers of a model state dictionary if they were saved
+    using DDP or if they will be used with DDP.
+    :param model: Model state dictionary to update.
+    :param to_ddp: Whether to rename the layers for use with DDP.
+    :return: The state dictionary with corrected layer names.
+    """
+    if to_ddp:
+        return {
+            ("module." + k if "module" not in k else k): v for k, v in model.items()
+        }
+    return {k.replace("module.", ""): v for k, v in model.items()}
+
+
 def list_to_batches(iterable, n):
     "Batch data into tuples of length n. The last batch may be shorter."
     # list_to_batches('ABCDEFG', 3) --> ABC DEF G
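A small usage sketch of the new helper: it only renames state-dictionary keys, adding or stripping the `module.` prefix that DDP introduces. The dummy keys below are illustrative; real state dictionaries map layer names to weight tensors.

```python
# Usage sketch of fix_ddp_layers_names with dummy values in place of tensors.
from dan.utils import fix_ddp_layers_names

state_dict = {"encoder.conv.weight": 1, "encoder.conv.bias": 2}

# Loading a plain checkpoint into a DDP-wrapped model: prefix every key.
ddp_state_dict = fix_ddp_layers_names(state_dict, to_ddp=True)
# {"module.encoder.conv.weight": 1, "module.encoder.conv.bias": 2}

# Loading a DDP checkpoint into a plain model: strip the prefix.
plain_state_dict = fix_ddp_layers_names(ddp_state_dict, to_ddp=False)
# {"encoder.conv.weight": 1, "encoder.conv.bias": 2}
```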
diff --git a/docs/usage/train/jeanzay.md b/docs/usage/train/jeanzay.md
index f486b76d..f7804274 100644
--- a/docs/usage/train/jeanzay.md
+++ b/docs/usage/train/jeanzay.md
@@ -41,6 +41,14 @@ set -x
 teklia-dan train document
 ```
 
+## Train on multiple GPUs
+
+To train on multiple GPUs, update the DDP-related parameters in the training configuration, as detailed in the [dedicated section](parameters.md#training-parameters). In addition, the number of GPUs required must be specified in the `train_dan.sh` file by updating the following line:
+
+```sh
+#SBATCH --gres=gpu:<nb_gpus>            # number of GPUs per node
+```
+
 ## Supervise a job
 
 - Use `squeue -u $USER`. This command should give an output similar to the one presented below.
diff --git a/docs/usage/train/parameters.md b/docs/usage/train/parameters.md
index c4a97265..731e970d 100644
--- a/docs/usage/train/parameters.md
+++ b/docs/usage/train/parameters.md
@@ -1,6 +1,6 @@
 # Training configuration
 
-All hyperparameters are specified and editable in the training scripts `dan/ocr/document/train.py::get_config` (descriptions are in comments). This page introduces some useful keys and theirs descriptions.
+All hyperparameters are specified and editable in the training script `dan/ocr/document/train.py::get_config` (descriptions are in comments). This page introduces some useful parameters and their descriptions.
 
 ## Dataset parameters
 
@@ -174,7 +174,8 @@ For a detailed description of all augmentation transforms, see the [dedicated pa
 | `training_params.label_noise_scheduler.max_error_rate`  | Maximum ratio of teacher forcing.                                           | `float`      | `0.2`                                       |
 | `training_params.label_noise_scheduler.total_num_steps` | Number of steps before stopping teacher forcing.                            | `float`      | `5e4`                                       |
 
-During the validation stage, the batch size is set to 1. This avoids problems associated with image sizes that can be very different inside batches and lead to significant padding, resulting in performance degradations.
+- To train on several GPUs, set the `training_params.use_ddp` parameter to `True`. By default, the model will use all available GPUs. To restrict training to a subset of them, set the `training_params.nb_gpu` parameter.
+- During the validation stage, the batch size is set to 1. This avoids problems associated with image sizes that can be very different inside batches and lead to significant padding, resulting in performance degradations.
 
 ## MLFlow logging
 
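For context, a hedged sketch of where these keys live in `dan/ocr/document/train.py::get_config`; the surrounding configuration keys are omitted and the exact values may differ:

```python
# Hedged sketch of the DDP-related entries in training_params; values are
# illustrative and the surrounding keys are omitted.
import torch

training_params = {
    "use_ddp": True,                      # train with DistributedDataParallel
    "nb_gpu": torch.cuda.device_count(),  # defaults to every available GPU
    # ...
}
```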
-- 
GitLab