From c1ca208ffbbf319c1e101b9d1463a4759c9bd9dc Mon Sep 17 00:00:00 2001
From: manonBlanco <blanco@teklia.com>
Date: Wed, 25 Oct 2023 14:05:51 +0200
Subject: [PATCH] Allow forcing training to a single GPU

---
 dan/ocr/manager/training.py | 2 +-
 dan/ocr/train.py            | 3 ++-
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/dan/ocr/manager/training.py b/dan/ocr/manager/training.py
index b75282f4..28183848 100644
--- a/dan/ocr/manager/training.py
+++ b/dan/ocr/manager/training.py
@@ -101,7 +101,7 @@ class GenericTrainingManager:
 
     def init_hardware_config(self):
         # Debug mode
-        if self.device_params["force"] == "cpu":
+        if self.device_params["force"] not in [None, "cuda"] or not torch.cuda.is_available():
            self.device_params["use_ddp"] = False
            self.device_params["use_amp"] = False
 
diff --git a/dan/ocr/train.py b/dan/ocr/train.py
index 03d4108b..d4bf423d 100644
--- a/dan/ocr/train.py
+++ b/dan/ocr/train.py
@@ -146,7 +146,8 @@ def serialize_config(config):
 def start_training(config, mlflow_logging: bool) -> None:
     if (
         config["training"]["device"]["use_ddp"]
-        and config["training"]["device"]["force"] != "cpu"
+        and config["training"]["device"]["force"] in [None, "cuda"]
+        and torch.cuda.is_available()
     ):
         mp.spawn(
             train_and_test,
--
GitLab
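
Note on the change above: the device config's "force" key previously only
distinguished "cpu" from everything else; with this patch, any value other
than None or plain "cuda" (for example "cpu" or a pinned device such as
"cuda:0") now disables DDP and AMP, as does an environment where CUDA is
unavailable. The sketch below reproduces that decision logic in isolation so
the new semantics are easy to test. The resolve_device helper name is
hypothetical (the real code lives in GenericTrainingManager.init_hardware_config),
and it assumes torch is importable; the patched hunks likewise rely on torch
already being imported in both files.

    import torch


    def resolve_device(device_params: dict) -> dict:
        """Disable DDP/AMP unless training can use the full set of CUDA devices.

        device_params["force"] may be:
          * None or "cuda" -> use CUDA normally (DDP/AMP stay enabled),
          * "cpu"          -> run on CPU,
          * "cuda:N"       -> pin training to a single GPU.
        In the last two cases, or when CUDA is unavailable, distributed
        training and automatic mixed precision are switched off.
        """
        if device_params["force"] not in [None, "cuda"] or not torch.cuda.is_available():
            device_params["use_ddp"] = False
            device_params["use_amp"] = False
        return device_params


    # Example: force training onto the second GPU only.
    params = {"force": "cuda:1", "use_ddp": True, "use_amp": True}
    print(resolve_device(params))
    # -> {'force': 'cuda:1', 'use_ddp': False, 'use_amp': False}

The second hunk applies the same rule to process spawning: start_training
only calls mp.spawn when DDP is enabled, "force" is None or plain "cuda",
and CUDA is actually available; forcing "cpu" or a single "cuda:N" device
falls through to a single-process run.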