diff --git a/dan/manager/dataset.py b/dan/manager/dataset.py index f0e9c8dfec90e6878aa3734bb7bcd74f8f34cd09..4ca9af2f8b1046318f13f9a7b46bb9bbc935e4b5 100644 --- a/dan/manager/dataset.py +++ b/dan/manager/dataset.py @@ -37,6 +37,8 @@ class DatasetManager: self.generator = torch.Generator() self.generator.manual_seed(0) + self.pin_memory_enabled = torch.cuda.is_available() + self.batch_size = { "train": self.params["batch_size"], "val": self.params["valid_batch_size"] @@ -115,7 +117,7 @@ class DatasetManager: batch_sampler=self.train_sampler, sampler=self.train_sampler, num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"], - pin_memory=False, + pin_memory=self.pin_memory_enabled, collate_fn=self.my_collate_function, worker_init_fn=self.seed_worker, generator=self.generator, @@ -129,7 +131,7 @@ class DatasetManager: batch_sampler=self.valid_samplers[key], shuffle=False, num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"], - pin_memory=False, + pin_memory=self.pin_memory_enabled, drop_last=False, collate_fn=self.my_collate_function, worker_init_fn=self.seed_worker, @@ -174,7 +176,7 @@ class DatasetManager: sampler=self.test_samplers[custom_name], shuffle=False, num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"], - pin_memory=False, + pin_memory=self.pin_memory_enabled, drop_last=False, collate_fn=self.my_collate_function, worker_init_fn=self.seed_worker,