Skip to content
Snippets Groups Projects

Use device to disable pin_memory behaviour

Merged Yoann Schneider requested to merge disable-pin-memory-in-ci into main
3 files
+ 11
7
Compare changes
  • Side-by-side
  • Inline
Files
3
+ 6
4
@@ -15,12 +15,14 @@ from dan.transforms import apply_data_augmentation
class DatasetManager:
def __init__(self, params):
def __init__(self, params, device: str):
self.params = params
self.dataset_class = None
self.img_padding_value = params["config"]["padding_value"]
self.my_collate_function = None
# Whether data should be copied on GPU via https://pytorch.org/docs/stable/generated/torch.Tensor.pin_memory.html
self.pin_memory = device != "cpu"
self.train_dataset = None
self.valid_datasets = dict()
@@ -115,7 +117,7 @@ class DatasetManager:
batch_sampler=self.train_sampler,
sampler=self.train_sampler,
num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
pin_memory=True,
pin_memory=self.pin_memory,
collate_fn=self.my_collate_function,
worker_init_fn=self.seed_worker,
generator=self.generator,
@@ -129,7 +131,7 @@ class DatasetManager:
batch_sampler=self.valid_samplers[key],
shuffle=False,
num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
pin_memory=True,
pin_memory=self.pin_memory,
drop_last=False,
collate_fn=self.my_collate_function,
worker_init_fn=self.seed_worker,
@@ -174,7 +176,7 @@ class DatasetManager:
sampler=self.test_samplers[custom_name],
shuffle=False,
num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
pin_memory=True,
pin_memory=self.pin_memory,
drop_last=False,
collate_fn=self.my_collate_function,
worker_init_fn=self.seed_worker,
Loading