diff --git a/dan/ocr/manager/training.py b/dan/ocr/manager/training.py
index b75282f48e96ec547d121eb8d9ef2d867289ba35..28183848eb6be5386b856d8fbe23c1c2e29539ba 100644
--- a/dan/ocr/manager/training.py
+++ b/dan/ocr/manager/training.py
@@ -101,7 +101,7 @@ class GenericTrainingManager:
     def init_hardware_config(self):
         # Debug mode
-        if self.device_params["force"] == "cpu":
+        if self.device_params["force"] not in [None, "cuda"] or not torch.cuda.is_available():
             self.device_params["use_ddp"] = False
             self.device_params["use_amp"] = False
diff --git a/dan/ocr/train.py b/dan/ocr/train.py
index 03d4108bea5d23f2866594ee7eefdf20abc6b02b..d4bf423d5c52af47d1856fdbd3b3765f198888cc 100644
--- a/dan/ocr/train.py
+++ b/dan/ocr/train.py
@@ -146,7 +146,8 @@ def serialize_config(config):
 def start_training(config, mlflow_logging: bool) -> None:
     if (
         config["training"]["device"]["use_ddp"]
-        and config["training"]["device"]["force"] != "cpu"
+        and config["training"]["device"]["force"] in [None, "cuda"]
+        and torch.cuda.is_available()
     ):
         mp.spawn(
             train_and_test,