From c1ca208ffbbf319c1e101b9d1463a4759c9bd9dc Mon Sep 17 00:00:00 2001
From: manonBlanco <blanco@teklia.com>
Date: Wed, 25 Oct 2023 14:05:51 +0200
Subject: [PATCH] Allow forcing a single GPU
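
The `device.force` parameter previously only distinguished a "cpu"
debug run from GPU training. It may now also name a single device
(e.g. "cuda:0"): any value other than None or "cuda", as well as an
unavailable CUDA runtime, disables DDP and AMP, and DDP workers are
only spawned when multi-GPU training is actually possible.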

---
 dan/ocr/manager/training.py | 7 +++++--
 dan/ocr/train.py            | 3 ++-
 2 files changed, 7 insertions(+), 3 deletions(-)
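
Note: a minimal sketch of the gating condition shared by both hunks
below (the helper name `ddp_enabled` is hypothetical, for illustration
only):

    import torch

    def ddp_enabled(force, use_ddp: bool) -> bool:
        # DDP (and AMP) stay on only when no specific device is forced
        # and CUDA is actually available.
        return use_ddp and force in [None, "cuda"] and torch.cuda.is_available()

    # ddp_enabled(None, True)      -> True on a CUDA machine
    # ddp_enabled("cuda", True)    -> True on a CUDA machine
    # ddp_enabled("cuda:0", True)  -> False, pinned to a single GPU
    # ddp_enabled("cpu", True)     -> False, CPU debug run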

diff --git a/dan/ocr/manager/training.py b/dan/ocr/manager/training.py
index b75282f4..28183848 100644
--- a/dan/ocr/manager/training.py
+++ b/dan/ocr/manager/training.py
@@ -101,7 +101,10 @@ class GenericTrainingManager:
 
     def init_hardware_config(self):
-        # Debug mode
-        if self.device_params["force"] == "cpu":
+        # Forced device or no CUDA: single-device mode, without DDP or AMP
+        if (
+            self.device_params["force"] not in [None, "cuda"]
+            or not torch.cuda.is_available()
+        ):
             self.device_params["use_ddp"] = False
             self.device_params["use_amp"] = False
 
diff --git a/dan/ocr/train.py b/dan/ocr/train.py
index 03d4108b..d4bf423d 100644
--- a/dan/ocr/train.py
+++ b/dan/ocr/train.py
@@ -146,7 +146,8 @@ def serialize_config(config):
 def start_training(config, mlflow_logging: bool) -> None:
     if (
         config["training"]["device"]["use_ddp"]
-        and config["training"]["device"]["force"] != "cpu"
+        and config["training"]["device"]["force"] in [None, "cuda"]
+        and torch.cuda.is_available()
     ):
         mp.spawn(
             train_and_test,
-- 
GitLab
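
For reference, a hypothetical training-config excerpt exercising the
new code path (the key names come from the hunks above; the "cuda:0"
value and the shape of the surrounding dict are assumptions):

    config = {
        "training": {
            "device": {
                # None or "cuda" keeps DDP possible; "cpu" is a CPU debug
                # run; a device string such as "cuda:0" pins training to
                # that single GPU.
                "force": "cuda:0",
                # Switched off by init_hardware_config once a device is
                # forced or CUDA is unavailable.
                "use_ddp": True,
            }
        }
    }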