diff --git a/dan/predict/__init__.py b/dan/predict/__init__.py
index c38e415c94469e0e0f3fc091f5151ccd2d03eda6..e8f3c961fece349e43acc15bd7f5b2caf41c60f1 100644
--- a/dan/predict/__init__.py
+++ b/dan/predict/__init__.py
@@ -148,4 +148,10 @@ def add_predict_parser(subcommands) -> None:
         type=int,
         default=0,
     )
+    parser.add_argument(
+        "--gpu-device",
+        help="Use a specific GPU if available.",
+        type=int,
+        required=False,
+    )
     parser.set_defaults(func=run)
diff --git a/dan/predict/prediction.py b/dan/predict/prediction.py
index db3d826f4e7fceb89b616bc65769944390711cf3..ac029c84d0891d23e24f095ef5ad2830b0567db3 100644
--- a/dan/predict/prediction.py
+++ b/dan/predict/prediction.py
@@ -391,6 +391,7 @@ def run(
     threshold_method,
     threshold_value,
     image_extension,
+    gpu_device,
 ):
     """
     Predict a single image save the output
@@ -411,13 +412,15 @@ def run(
     :param predict_objects: Whether to extract objects.
     :param threshold_method: Thresholding method. Should be in ["otsu", "simple"].
     :param threshold_value: Thresholding value to use for the "simple" thresholding method.
+    :param gpu_device: Use a specific GPU if available.
     """
     # Create output directory if necessary
     if not os.path.exists(output):
         os.makedirs(output, exist_ok=True)
 
     # Load model
-    device = "cuda" if torch.cuda.is_available() else "cpu"
+    cuda_device = f":{gpu_device}" if gpu_device is not None else ""
+    device = f"cuda{cuda_device}" if torch.cuda.is_available() else "cpu"
     dan_model = DAN(device, temperature)
     dan_model.load(model, parameters, charset, mode="train")
     if image: