Skip to content
Snippets Groups Projects

Allow specifying device to use when training

Merged · Manon Blanco requested to merge `export-device-when-training` into `main`
1 file
+ 1
1
Compare changes
  • Side-by-side
  • Inline
+ 10
11
@@ -100,8 +100,10 @@ class GenericTrainingManager:
self.dataset.load_dataloaders()
def init_hardware_config(self):
cuda_is_available = torch.cuda.is_available()
# Debug mode
if self.device_params["force_cpu"]:
if self.device_params["force"] not in [None, "cuda"] or not cuda_is_available:
self.device_params["use_ddp"] = False
self.device_params["use_amp"] = False
@@ -116,17 +118,14 @@ class GenericTrainingManager:
"rank": self.device_params["ddp_rank"],
}
self.is_master = self.ddp_config["master"] or not self.device_params["use_ddp"]
if self.device_params["force_cpu"]:
self.device = torch.device("cpu")
if self.device_params["use_ddp"]:
self.device = torch.device(self.ddp_config["rank"])
self.device_params["ddp_rank"] = self.ddp_config["rank"]
self.launch_ddp()
else:
if self.device_params["use_ddp"]:
self.device = torch.device(self.ddp_config["rank"])
self.device_params["ddp_rank"] = self.ddp_config["rank"]
self.launch_ddp()
else:
self.device = torch.device(
"cuda:0" if torch.cuda.is_available() else "cpu"
)
self.device = torch.device(
self.device_params["force"] or "cuda" if cuda_is_available else "cpu"
)
if self.device == torch.device("cpu"):
self.params["model"]["device"] = "cpu"
else:
Loading