diff --git a/dan/manager/dataset.py b/dan/manager/dataset.py
index 2b28e6dc1428b57a0c247ef92b8ef6acb8b32377..41f9c433c657ee5c2445d9d47250bda9a4bfa4f3 100644
--- a/dan/manager/dataset.py
+++ b/dan/manager/dataset.py
@@ -15,12 +15,14 @@ from dan.transforms import apply_data_augmentation
 
 
 class DatasetManager:
-    def __init__(self, params):
+    def __init__(self, params, device: str):
         self.params = params
         self.dataset_class = None
         self.img_padding_value = params["config"]["padding_value"]
 
         self.my_collate_function = None
+        # Whether data should be copied on GPU via https://pytorch.org/docs/stable/generated/torch.Tensor.pin_memory.html
+        self.pin_memory = device != "cpu"
 
         self.train_dataset = None
         self.valid_datasets = dict()
@@ -115,7 +117,7 @@ class DatasetManager:
             batch_sampler=self.train_sampler,
             sampler=self.train_sampler,
             num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
-            pin_memory=True,
+            pin_memory=self.pin_memory,
             collate_fn=self.my_collate_function,
             worker_init_fn=self.seed_worker,
             generator=self.generator,
@@ -129,7 +131,7 @@ class DatasetManager:
                 batch_sampler=self.valid_samplers[key],
                 shuffle=False,
                 num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
-                pin_memory=True,
+                pin_memory=self.pin_memory,
                 drop_last=False,
                 collate_fn=self.my_collate_function,
                 worker_init_fn=self.seed_worker,
@@ -174,7 +176,7 @@ class DatasetManager:
             sampler=self.test_samplers[custom_name],
             shuffle=False,
             num_workers=self.params["num_gpu"] * self.params["worker_per_gpu"],
-            pin_memory=True,
+            pin_memory=self.pin_memory,
             drop_last=False,
             collate_fn=self.my_collate_function,
             worker_init_fn=self.seed_worker,
diff --git a/dan/manager/ocr.py b/dan/manager/ocr.py
index 45d851c3bb558796c29db980a8d84380d61c6908..2fd46fb8be1d30d949d0445e6b552317affe88e1 100644
--- a/dan/manager/ocr.py
+++ b/dan/manager/ocr.py
@@ -27,8 +27,8 @@ class OCRDatasetManager(DatasetManager):
     Specific class to handle OCR/HTR tasks
     """
 
-    def __init__(self, params):
-        super(OCRDatasetManager, self).__init__(params)
+    def __init__(self, params, device: str):
+        super(OCRDatasetManager, self).__init__(params, device)
 
         self.dataset_class = OCRDataset
         self.charset = (
diff --git a/dan/manager/training.py b/dan/manager/training.py
index 53012a7c22166f5461bc4f55d840b5b191dc8078..52380fe8b9272472a2ced5d41a26b69df604bc3c 100644
--- a/dan/manager/training.py
+++ b/dan/manager/training.py
@@ -116,7 +116,9 @@ class GenericTrainingManager:
             if "worker_per_gpu" not in self.params["dataset_params"]
             else self.params["dataset_params"]["worker_per_gpu"]
         )
-        self.dataset = OCRDatasetManager(self.params["dataset_params"])
+        self.dataset = OCRDatasetManager(
+            self.params["dataset_params"], device=self.device
+        )
         self.dataset.load_datasets()
         self.dataset.load_ddp_samplers()
         self.dataset.load_dataloaders()