Skip to content
Snippets Groups Projects
Verified Commit 46de9ddf authored by Yoann Schneider's avatar Yoann Schneider :tennis:
Browse files

depend on torch cuda status to use pin_memory

parent e5b443b3
No related branches found
No related tags found
No related merge requests found
......@@ -37,6 +37,8 @@ class DatasetManager:
self.generator = torch.Generator()
self.generator.manual_seed(0)
self.pin_memory_enabled = torch.cuda.is_available()
self.batch_size = {
"train": self.params["batch_size"],
"val": self.params["valid_batch_size"]
......@@ -115,7 +117,7 @@ class DatasetManager:
batch_sampler=self.train_sampler,
sampler=self.train_sampler,
num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
pin_memory=False,
pin_memory=self.pin_memory_enabled,
collate_fn=self.my_collate_function,
worker_init_fn=self.seed_worker,
generator=self.generator,
......@@ -129,7 +131,7 @@ class DatasetManager:
batch_sampler=self.valid_samplers[key],
shuffle=False,
num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
pin_memory=False,
pin_memory=self.pin_memory_enabled,
drop_last=False,
collate_fn=self.my_collate_function,
worker_init_fn=self.seed_worker,
......@@ -174,7 +176,7 @@ class DatasetManager:
sampler=self.test_samplers[custom_name],
shuffle=False,
num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
pin_memory=False,
pin_memory=self.pin_memory_enabled,
drop_last=False,
collate_fn=self.my_collate_function,
worker_init_fn=self.seed_worker,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment