From 66333f6a5f697c05464e69b4f3dd618d6cd82474 Mon Sep 17 00:00:00 2001
From: Yoann Schneider <yschneider@teklia.com>
Date: Thu, 25 May 2023 08:58:37 +0000
Subject: [PATCH] Use device to disable pin_memory behaviour

---
 dan/manager/dataset.py  | 10 ++++++----
 dan/manager/ocr.py      |  4 ++--
 dan/manager/training.py |  4 +++-
 3 files changed, 11 insertions(+), 7 deletions(-)

diff --git a/dan/manager/dataset.py b/dan/manager/dataset.py
index 2b28e6dc..41f9c433 100644
--- a/dan/manager/dataset.py
+++ b/dan/manager/dataset.py
@@ -15,12 +15,14 @@ from dan.transforms import apply_data_augmentation
 
 
 class DatasetManager:
-    def __init__(self, params):
+    def __init__(self, params, device: str):
         self.params = params
         self.dataset_class = None
         self.img_padding_value = params["config"]["padding_value"]
 
         self.my_collate_function = None
+        # Whether to load batches into pinned (page-locked) host memory, which speeds up host-to-GPU transfer; see https://pytorch.org/docs/stable/generated/torch.Tensor.pin_memory.html
+        self.pin_memory = device != "cpu"
 
         self.train_dataset = None
         self.valid_datasets = dict()
@@ -115,7 +117,7 @@ class DatasetManager:
             batch_sampler=self.train_sampler,
             sampler=self.train_sampler,
             num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
-            pin_memory=True,
+            pin_memory=self.pin_memory,
             collate_fn=self.my_collate_function,
             worker_init_fn=self.seed_worker,
             generator=self.generator,
@@ -129,7 +131,7 @@ class DatasetManager:
                 batch_sampler=self.valid_samplers[key],
                 shuffle=False,
                 num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
-                pin_memory=True,
+                pin_memory=self.pin_memory,
                 drop_last=False,
                 collate_fn=self.my_collate_function,
                 worker_init_fn=self.seed_worker,
@@ -174,7 +176,7 @@ class DatasetManager:
             sampler=self.test_samplers[custom_name],
             shuffle=False,
             num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
-            pin_memory=True,
+            pin_memory=self.pin_memory,
             drop_last=False,
             collate_fn=self.my_collate_function,
             worker_init_fn=self.seed_worker,
diff --git a/dan/manager/ocr.py b/dan/manager/ocr.py
index 45d851c3..2fd46fb8 100644
--- a/dan/manager/ocr.py
+++ b/dan/manager/ocr.py
@@ -27,8 +27,8 @@ class OCRDatasetManager(DatasetManager):
     Specific class to handle OCR/HTR tasks
     """
 
-    def __init__(self, params):
-        super(OCRDatasetManager, self).__init__(params)
+    def __init__(self, params, device: str):
+        super(OCRDatasetManager, self).__init__(params, device)
 
         self.dataset_class = OCRDataset
         self.charset = (
diff --git a/dan/manager/training.py b/dan/manager/training.py
index 53012a7c..52380fe8 100644
--- a/dan/manager/training.py
+++ b/dan/manager/training.py
@@ -116,7 +116,9 @@ class GenericTrainingManager:
             if "worker_per_gpu" not in self.params["dataset_params"]
             else self.params["dataset_params"]["worker_per_gpu"]
         )
-        self.dataset = OCRDatasetManager(self.params["dataset_params"])
+        self.dataset = OCRDatasetManager(
+            self.params["dataset_params"], device=self.device
+        )
         self.dataset.load_datasets()
         self.dataset.load_ddp_samplers()
         self.dataset.load_dataloaders()
-- 
GitLab