From 66333f6a5f697c05464e69b4f3dd618d6cd82474 Mon Sep 17 00:00:00 2001 From: Yoann Schneider <yschneider@teklia.com> Date: Thu, 25 May 2023 08:58:37 +0000 Subject: [PATCH] Use device to disable pin_memory behaviour --- dan/manager/dataset.py | 10 ++++++---- dan/manager/ocr.py | 4 ++-- dan/manager/training.py | 4 +++- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/dan/manager/dataset.py b/dan/manager/dataset.py index 2b28e6dc..41f9c433 100644 --- a/dan/manager/dataset.py +++ b/dan/manager/dataset.py @@ -15,12 +15,14 @@ from dan.transforms import apply_data_augmentation class DatasetManager: - def __init__(self, params): + def __init__(self, params, device: str): self.params = params self.dataset_class = None self.img_padding_value = params["config"]["padding_value"] self.my_collate_function = None + # Whether data should be copied on GPU via https://pytorch.org/docs/stable/generated/torch.Tensor.pin_memory.html + self.pin_memory = device != "cpu" self.train_dataset = None self.valid_datasets = dict() @@ -115,7 +117,7 @@ class DatasetManager: batch_sampler=self.train_sampler, sampler=self.train_sampler, num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"], - pin_memory=True, + pin_memory=self.pin_memory, collate_fn=self.my_collate_function, worker_init_fn=self.seed_worker, generator=self.generator, @@ -129,7 +131,7 @@ class DatasetManager: batch_sampler=self.valid_samplers[key], shuffle=False, num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"], - pin_memory=True, + pin_memory=self.pin_memory, drop_last=False, collate_fn=self.my_collate_function, worker_init_fn=self.seed_worker, @@ -174,7 +176,7 @@ class DatasetManager: sampler=self.test_samplers[custom_name], shuffle=False, num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"], - pin_memory=True, + pin_memory=self.pin_memory, drop_last=False, collate_fn=self.my_collate_function, worker_init_fn=self.seed_worker, diff --git a/dan/manager/ocr.py b/dan/manager/ocr.py index 45d851c3..2fd46fb8 100644 --- a/dan/manager/ocr.py +++ b/dan/manager/ocr.py @@ -27,8 +27,8 @@ class OCRDatasetManager(DatasetManager): Specific class to handle OCR/HTR tasks """ - def __init__(self, params): - super(OCRDatasetManager, self).__init__(params) + def __init__(self, params, device: str): + super(OCRDatasetManager, self).__init__(params, device) self.dataset_class = OCRDataset self.charset = ( diff --git a/dan/manager/training.py b/dan/manager/training.py index 53012a7c..52380fe8 100644 --- a/dan/manager/training.py +++ b/dan/manager/training.py @@ -116,7 +116,9 @@ class GenericTrainingManager: if "worker_per_gpu" not in self.params["dataset_params"] else self.params["dataset_params"]["worker_per_gpu"] ) - self.dataset = OCRDatasetManager(self.params["dataset_params"]) + self.dataset = OCRDatasetManager( + self.params["dataset_params"], device=self.device + ) self.dataset.load_datasets() self.dataset.load_ddp_samplers() self.dataset.load_dataloaders() -- GitLab