Skip to content
Snippets Groups Projects
Verified Commit 46de9ddf authored by Yoann Schneider's avatar Yoann Schneider :tennis:
Browse files

depend on torch cuda status to use pin_memory

parent e5b443b3
No related branches found
No related tags found
No related merge requests found
...@@ -37,6 +37,8 @@ class DatasetManager: ...@@ -37,6 +37,8 @@ class DatasetManager:
self.generator = torch.Generator() self.generator = torch.Generator()
self.generator.manual_seed(0) self.generator.manual_seed(0)
self.pin_memory_enabled = torch.cuda.is_available()
self.batch_size = { self.batch_size = {
"train": self.params["batch_size"], "train": self.params["batch_size"],
"val": self.params["valid_batch_size"] "val": self.params["valid_batch_size"]
...@@ -115,7 +117,7 @@ class DatasetManager: ...@@ -115,7 +117,7 @@ class DatasetManager:
batch_sampler=self.train_sampler, batch_sampler=self.train_sampler,
sampler=self.train_sampler, sampler=self.train_sampler,
num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"], num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
pin_memory=False, pin_memory=self.pin_memory_enabled,
collate_fn=self.my_collate_function, collate_fn=self.my_collate_function,
worker_init_fn=self.seed_worker, worker_init_fn=self.seed_worker,
generator=self.generator, generator=self.generator,
...@@ -129,7 +131,7 @@ class DatasetManager: ...@@ -129,7 +131,7 @@ class DatasetManager:
batch_sampler=self.valid_samplers[key], batch_sampler=self.valid_samplers[key],
shuffle=False, shuffle=False,
num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"], num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
pin_memory=False, pin_memory=self.pin_memory_enabled,
drop_last=False, drop_last=False,
collate_fn=self.my_collate_function, collate_fn=self.my_collate_function,
worker_init_fn=self.seed_worker, worker_init_fn=self.seed_worker,
...@@ -174,7 +176,7 @@ class DatasetManager: ...@@ -174,7 +176,7 @@ class DatasetManager:
sampler=self.test_samplers[custom_name], sampler=self.test_samplers[custom_name],
shuffle=False, shuffle=False,
num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"], num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
pin_memory=False, pin_memory=self.pin_memory_enabled,
drop_last=False, drop_last=False,
collate_fn=self.my_collate_function, collate_fn=self.my_collate_function,
worker_init_fn=self.seed_worker, worker_init_fn=self.seed_worker,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment