Allow training to be restarted on a different GPU
I started training on GPU 1, and when I tried to restart it on GPU 0, I got this error:
```
File "/gpfsdswork/projects/rech/rxm/uoh22tq/dan/dan/manager/training.py", line 226, in get_checkpoint
    return torch.load(os.path.join(self.paths["checkpoints"], filename))
[...]
RuntimeError: Attempting to deserialize object on CUDA device 1 but torch.cuda.device_count() is 1. Please use torch.load with map_location to map your storages to an existing device.
```
According to Mélodie, we should add a map_location argument set to the correct device (cuda:0 or cpu) when calling torch.load.
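A minimal sketch of what the fix could look like, assuming we want to remap the checkpoint onto whichever device the current run can see (the checkpoint path below is a placeholder; in get_checkpoint it would stay os.path.join(self.paths["checkpoints"], filename)):

```python
import torch

# Pick a device that exists on the current machine; the checkpoint was
# saved on cuda:1, and map_location remaps its tensors onto this device
# instead of failing when that GPU is not visible.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Placeholder path for illustration only.
checkpoint = torch.load("checkpoints/last_checkpoint.pt", map_location=device)
```

Using map_location="cpu" is the most conservative choice, since it works however many GPUs are visible; the loaded weights can then be moved to the desired GPU afterwards with .to(device).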