From 46de9ddf5e216e89b001498c446526f182fe5ea6 Mon Sep 17 00:00:00 2001
From: Yoann Schneider <yschneider@teklia.com>
Date: Wed, 24 May 2023 09:31:00 +0200
Subject: [PATCH] depend on torch cuda status to use pin_memory

---
 dan/manager/dataset.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/dan/manager/dataset.py b/dan/manager/dataset.py
index f0e9c8df..4ca9af2f 100644
--- a/dan/manager/dataset.py
+++ b/dan/manager/dataset.py
@@ -37,6 +37,8 @@ class DatasetManager:
         self.generator = torch.Generator()
         self.generator.manual_seed(0)
 
+        self.pin_memory_enabled = torch.cuda.is_available()
+
         self.batch_size = {
             "train": self.params["batch_size"],
             "val": self.params["valid_batch_size"]
@@ -115,7 +117,7 @@ class DatasetManager:
             batch_sampler=self.train_sampler,
             sampler=self.train_sampler,
             num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
-            pin_memory=False,
+            pin_memory=self.pin_memory_enabled,
             collate_fn=self.my_collate_function,
             worker_init_fn=self.seed_worker,
             generator=self.generator,
@@ -129,7 +131,7 @@ class DatasetManager:
                 batch_sampler=self.valid_samplers[key],
                 shuffle=False,
                 num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
-                pin_memory=False,
+                pin_memory=self.pin_memory_enabled,
                 drop_last=False,
                 collate_fn=self.my_collate_function,
                 worker_init_fn=self.seed_worker,
@@ -174,7 +176,7 @@ class DatasetManager:
             sampler=self.test_samplers[custom_name],
             shuffle=False,
             num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
-            pin_memory=False,
+            pin_memory=self.pin_memory_enabled,
             drop_last=False,
             collate_fn=self.my_collate_function,
             worker_init_fn=self.seed_worker,
-- 
GitLab