From d2760687177148f572ea03bc7dbb79b2fda3160a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?M=C3=A9lodie=20Boillet?= <boillet@teklia.com>
Date: Mon, 10 Jul 2023 14:48:01 +0000
Subject: [PATCH] Fix validation batch size to 1

---
 dan/manager/ocr.py             | 22 +++-------------------
 dan/manager/training.py        |  8 --------
 dan/ocr/document/train.py      |  1 -
 docs/usage/train/parameters.md |  2 +-
 tests/conftest.py              |  1 -
 5 files changed, 4 insertions(+), 30 deletions(-)

diff --git a/dan/manager/ocr.py b/dan/manager/ocr.py
index 383ac455..39dc4811 100644
--- a/dan/manager/ocr.py
+++ b/dan/manager/ocr.py
@@ -39,8 +39,6 @@ class OCRDatasetManager:
         self.generator = torch.Generator()
         self.generator.manual_seed(0)
 
-        self.batch_size = self.get_batch_size_by_set()
-
         self.load_in_memory = (
             self.params["config"]["load_in_memory"]
             if "load_in_memory" in self.params["config"]
@@ -116,7 +114,7 @@ class OCRDatasetManager:
         """
         self.train_loader = DataLoader(
             self.train_dataset,
-            batch_size=self.batch_size["train"],
+            batch_size=self.params["batch_size"],
             shuffle=True if self.train_sampler is None else False,
             drop_last=False,
             batch_sampler=self.train_sampler,
@@ -131,7 +129,7 @@ class OCRDatasetManager:
         for key in self.valid_datasets.keys():
             self.valid_loaders[key] = DataLoader(
                 self.valid_datasets[key],
-                batch_size=self.batch_size["val"],
+                batch_size=1,
                 sampler=self.valid_samplers[key],
                 batch_sampler=self.valid_samplers[key],
                 shuffle=False,
@@ -185,7 +183,7 @@ class OCRDatasetManager:
             self.test_samplers[custom_name] = None
         self.test_loaders[custom_name] = DataLoader(
             self.test_datasets[custom_name],
-            batch_size=self.batch_size["test"],
+            batch_size=1,
             sampler=self.test_samplers[custom_name],
             shuffle=False,
             num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
@@ -231,20 +229,6 @@ class OCRDatasetManager:
             "pad": len(self.charset) + 2,
         }
 
-    def get_batch_size_by_set(self):
-        """
-        Return batch size for each set
-        """
-        return {
-            "train": self.params["batch_size"],
-            "val": self.params["valid_batch_size"]
-            if "valid_batch_size" in self.params
-            else self.params["batch_size"],
-            "test": self.params["test_batch_size"]
-            if "test_batch_size" in self.params
-            else 1,
-        }
-
 
 class OCRCollateFunction:
     """
diff --git a/dan/manager/training.py b/dan/manager/training.py
index 3b472a2b..f231e372 100644
--- a/dan/manager/training.py
+++ b/dan/manager/training.py
@@ -84,14 +84,6 @@ class GenericTrainingManager:
         self.params["dataset_params"]["batch_size"] = self.params["training_params"][
             "batch_size"
         ]
-        if "valid_batch_size" in self.params["training_params"]:
-            self.params["dataset_params"]["valid_batch_size"] = self.params[
-                "training_params"
-            ]["valid_batch_size"]
-        if "test_batch_size" in self.params["training_params"]:
-            self.params["dataset_params"]["test_batch_size"] = self.params[
-                "training_params"
-            ]["test_batch_size"]
         self.params["dataset_params"]["num_gpu"] = self.params["training_params"][
             "nb_gpu"
         ]
diff --git a/dan/ocr/document/train.py b/dan/ocr/document/train.py
index 189bc2b9..7ba39de3 100644
--- a/dan/ocr/document/train.py
+++ b/dan/ocr/document/train.py
@@ -170,7 +170,6 @@ def get_config():
             "load_epoch": "last",  # ["best", "last"]: last to continue training, best to evaluate
             "interval_save_weights": None,  # None: keep best and last only
             "batch_size": 2,  # mini-batch size for training
-            "valid_batch_size": 4,  # mini-batch size for valdiation
             "use_ddp": False,  # Use DistributedDataParallel
             "ddp_port": "20027",
             "use_amp": True,  # Enable automatic mix-precision
diff --git a/docs/usage/train/parameters.md b/docs/usage/train/parameters.md
index 562966dc..5e2b8b29 100644
--- a/docs/usage/train/parameters.md
+++ b/docs/usage/train/parameters.md
@@ -153,7 +153,6 @@ For a detailed description of all augmentation transforms, see the [dedicated pa
 | `training_params.load_epoch`                            | Model to load. Should be either `"best"` (evaluation) or `last` (training). | `str`        | `"last"`                                    |
 | `training_params.interval_save_weights`                 | Step to save weights. Set to `None` to keep only best and last epochs.      | `int`        | `None`                                      |
 | `training_params.batch_size`                            | Mini-batch size for the training loop.                                      | `int`        | `2`                                         |
-| `training_params.valid_batch_size`                      | Mini-batch size for the valdiation loop.                                    | `int`        | `4`                                         |
 | `training_params.use_ddp`                               | Whether to use DistributedDataParallel.                                     | `bool`       | `False`                                     |
 | `training_params.ddp_port`                              | DDP port.                                                                   | `int`        | `20027`                                     |
 | `training_params.use_amp`                               | Whether to enable automatic mix-precision.                                  | `int`        | `torch.cuda.device_count()`                 |
@@ -175,6 +174,7 @@ For a detailed description of all augmentation transforms, see the [dedicated pa
 | `training_params.label_noise_scheduler.max_error_rate`  | Maximum ratio of teacher forcing.                                           | `float`      | `0.2`                                       |
 | `training_params.label_noise_scheduler.total_num_steps` | Number of steps before stopping teacher forcing.                            | `float`      | `5e4`                                       |
 
+During the validation stage, the batch size is fixed to 1. This avoids issues caused by batching images of very different sizes, which would require significant padding and degrade performance.
 
 ## MLFlow logging
 
diff --git a/tests/conftest.py b/tests/conftest.py
index a38d6a01..e398d9df 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -116,7 +116,6 @@ def training_config():
             "load_epoch": "last",  # ["best", "last"]: last to continue training, best to evaluate
             "interval_save_weights": None,  # None: keep best and last only
             "batch_size": 2,  # mini-batch size for training
-            "valid_batch_size": 2,  # mini-batch size for valdiation
             "use_ddp": False,  # Use DistributedDataParallel
             "nb_gpu": 0,
             "optimizers": {
-- 
GitLab