diff --git a/dan/manager/dataset.py b/dan/manager/dataset.py index 2b28e6dc1428b57a0c247ef92b8ef6acb8b32377..41f9c433c657ee5c2445d9d47250bda9a4bfa4f3 100644 --- a/dan/manager/dataset.py +++ b/dan/manager/dataset.py @@ -15,12 +15,14 @@ from dan.transforms import apply_data_augmentation class DatasetManager: - def __init__(self, params): + def __init__(self, params, device: str): self.params = params self.dataset_class = None self.img_padding_value = params["config"]["padding_value"] self.my_collate_function = None + # Whether data should be copied on GPU via https://pytorch.org/docs/stable/generated/torch.Tensor.pin_memory.html + self.pin_memory = device != "cpu" self.train_dataset = None self.valid_datasets = dict() @@ -115,7 +117,7 @@ class DatasetManager: batch_sampler=self.train_sampler, sampler=self.train_sampler, num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"], - pin_memory=True, + pin_memory=self.pin_memory, collate_fn=self.my_collate_function, worker_init_fn=self.seed_worker, generator=self.generator, @@ -129,7 +131,7 @@ class DatasetManager: batch_sampler=self.valid_samplers[key], shuffle=False, num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"], - pin_memory=True, + pin_memory=self.pin_memory, drop_last=False, collate_fn=self.my_collate_function, worker_init_fn=self.seed_worker, @@ -174,7 +176,7 @@ class DatasetManager: sampler=self.test_samplers[custom_name], shuffle=False, num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"], - pin_memory=True, + pin_memory=self.pin_memory, drop_last=False, collate_fn=self.my_collate_function, worker_init_fn=self.seed_worker, diff --git a/dan/manager/ocr.py b/dan/manager/ocr.py index 45d851c3bb558796c29db980a8d84380d61c6908..2fd46fb8be1d30d949d0445e6b552317affe88e1 100644 --- a/dan/manager/ocr.py +++ b/dan/manager/ocr.py @@ -27,8 +27,8 @@ class OCRDatasetManager(DatasetManager): Specific class to handle OCR/HTR tasks """ - def __init__(self, params): - super(OCRDatasetManager, self).__init__(params) + def __init__(self, params, device: str): + super(OCRDatasetManager, self).__init__(params, device) self.dataset_class = OCRDataset self.charset = ( diff --git a/dan/manager/training.py b/dan/manager/training.py index 53012a7c22166f5461bc4f55d840b5b191dc8078..52380fe8b9272472a2ced5d41a26b69df604bc3c 100644 --- a/dan/manager/training.py +++ b/dan/manager/training.py @@ -116,7 +116,9 @@ class GenericTrainingManager: if "worker_per_gpu" not in self.params["dataset_params"] else self.params["dataset_params"]["worker_per_gpu"] ) - self.dataset = OCRDatasetManager(self.params["dataset_params"]) + self.dataset = OCRDatasetManager( + self.params["dataset_params"], device=self.device + ) self.dataset.load_datasets() self.dataset.load_ddp_samplers() self.dataset.load_dataloaders()