diff --git a/dan/predict/__init__.py b/dan/predict/__init__.py index c38e415c94469e0e0f3fc091f5151ccd2d03eda6..e8f3c961fece349e43acc15bd7f5b2caf41c60f1 100644 --- a/dan/predict/__init__.py +++ b/dan/predict/__init__.py @@ -148,4 +148,10 @@ def add_predict_parser(subcommands) -> None: type=int, default=0, ) + parser.add_argument( + "--gpu-device", + help="Use a specific GPU if available.", + type=int, + required=False, + ) parser.set_defaults(func=run) diff --git a/dan/predict/prediction.py b/dan/predict/prediction.py index db3d826f4e7fceb89b616bc65769944390711cf3..ac029c84d0891d23e24f095ef5ad2830b0567db3 100644 --- a/dan/predict/prediction.py +++ b/dan/predict/prediction.py @@ -391,6 +391,7 @@ def run( threshold_method, threshold_value, image_extension, + gpu_device, ): """ Predict a single image save the output @@ -411,13 +412,15 @@ def run( :param predict_objects: Whether to extract objects. :param threshold_method: Thresholding method. Should be in ["otsu", "simple"]. :param threshold_value: Thresholding value to use for the "simple" thresholding method. + :param gpu_device: Use a specific GPU if available. """ # Create output directory if necessary if not os.path.exists(output): os.makedirs(output, exist_ok=True) # Load model - device = "cuda" if torch.cuda.is_available() else "cpu" + cuda_device = f":{gpu_device}" if gpu_device is not None else "" + device = f"cuda{cuda_device}" if torch.cuda.is_available() else "cpu" dan_model = DAN(device, temperature) dan_model.load(model, parameters, charset, mode="train") if image: