Skip to content
Snippets Groups Projects
Commit 66333f6a authored by Yoann Schneider's avatar Yoann Schneider :tennis:
Browse files

Use device to disable pin_memory behaviour

parent 04814ad8
No related branches found
No related tags found
1 merge request!138Use device to disable pin_memory behaviour
......@@ -15,12 +15,14 @@ from dan.transforms import apply_data_augmentation
class DatasetManager:
def __init__(self, params):
def __init__(self, params, device: str):
self.params = params
self.dataset_class = None
self.img_padding_value = params["config"]["padding_value"]
self.my_collate_function = None
# Whether data should be copied on GPU via https://pytorch.org/docs/stable/generated/torch.Tensor.pin_memory.html
self.pin_memory = device != "cpu"
self.train_dataset = None
self.valid_datasets = dict()
......@@ -115,7 +117,7 @@ class DatasetManager:
batch_sampler=self.train_sampler,
sampler=self.train_sampler,
num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
pin_memory=True,
pin_memory=self.pin_memory,
collate_fn=self.my_collate_function,
worker_init_fn=self.seed_worker,
generator=self.generator,
......@@ -129,7 +131,7 @@ class DatasetManager:
batch_sampler=self.valid_samplers[key],
shuffle=False,
num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
pin_memory=True,
pin_memory=self.pin_memory,
drop_last=False,
collate_fn=self.my_collate_function,
worker_init_fn=self.seed_worker,
......@@ -174,7 +176,7 @@ class DatasetManager:
sampler=self.test_samplers[custom_name],
shuffle=False,
num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
pin_memory=True,
pin_memory=self.pin_memory,
drop_last=False,
collate_fn=self.my_collate_function,
worker_init_fn=self.seed_worker,
......
......@@ -27,8 +27,8 @@ class OCRDatasetManager(DatasetManager):
Specific class to handle OCR/HTR tasks
"""
def __init__(self, params):
super(OCRDatasetManager, self).__init__(params)
def __init__(self, params, device: str):
super(OCRDatasetManager, self).__init__(params, device)
self.dataset_class = OCRDataset
self.charset = (
......
......@@ -116,7 +116,9 @@ class GenericTrainingManager:
if "worker_per_gpu" not in self.params["dataset_params"]
else self.params["dataset_params"]["worker_per_gpu"]
)
self.dataset = OCRDatasetManager(self.params["dataset_params"])
self.dataset = OCRDatasetManager(
self.params["dataset_params"], device=self.device
)
self.dataset.load_datasets()
self.dataset.load_ddp_samplers()
self.dataset.load_dataloaders()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment