From df229f934bc16a47e0591ee650926f870708cf49 Mon Sep 17 00:00:00 2001
From: Manon blanco <blanco@teklia.com>
Date: Thu, 6 Jul 2023 15:28:40 +0000
Subject: [PATCH] Enable to use GPU for teklia_dan_predict

---
 dan/predict/__init__.py   | 6 ++++++
 dan/predict/prediction.py | 5 ++++-
 2 files changed, 10 insertions(+), 1 deletion(-)

diff --git a/dan/predict/__init__.py b/dan/predict/__init__.py
index c38e415c..e8f3c961 100644
--- a/dan/predict/__init__.py
+++ b/dan/predict/__init__.py
@@ -148,4 +148,10 @@ def add_predict_parser(subcommands) -> None:
         type=int,
         default=0,
     )
+    parser.add_argument(
+        "--gpu-device",
+        help="Use a specific GPU if available.",
+        type=int,
+        required=False,
+    )
     parser.set_defaults(func=run)
diff --git a/dan/predict/prediction.py b/dan/predict/prediction.py
index db3d826f..ac029c84 100644
--- a/dan/predict/prediction.py
+++ b/dan/predict/prediction.py
@@ -391,6 +391,7 @@ def run(
     threshold_method,
     threshold_value,
     image_extension,
+    gpu_device,
 ):
     """
     Predict a single image save the output
@@ -411,13 +412,15 @@ def run(
     :param predict_objects: Whether to extract objects.
     :param threshold_method: Thresholding method. Should be in ["otsu", "simple"].
     :param threshold_value: Thresholding value to use for the "simple" thresholding method.
+    :param gpu_device: Use a specific GPU if available.
     """
     # Create output directory if necessary
     if not os.path.exists(output):
         os.makedirs(output, exist_ok=True)
 
     # Load model
-    device = "cuda" if torch.cuda.is_available() else "cpu"
+    cuda_device = f":{gpu_device}" if gpu_device is not None else ""
+    device = f"cuda{cuda_device}" if torch.cuda.is_available() else "cpu"
     dan_model = DAN(device, temperature)
     dan_model.load(model, parameters, charset, mode="train")
     if image:
-- 
GitLab