From df229f934bc16a47e0591ee650926f870708cf49 Mon Sep 17 00:00:00 2001 From: Manon blanco <blanco@teklia.com> Date: Thu, 6 Jul 2023 15:28:40 +0000 Subject: [PATCH] Enable to use GPU for teklia_dan_predict --- dan/predict/__init__.py | 6 ++++++ dan/predict/prediction.py | 5 ++++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/dan/predict/__init__.py b/dan/predict/__init__.py index c38e415c..e8f3c961 100644 --- a/dan/predict/__init__.py +++ b/dan/predict/__init__.py @@ -148,4 +148,10 @@ def add_predict_parser(subcommands) -> None: type=int, default=0, ) + parser.add_argument( + "--gpu-device", + help="Use a specific GPU if available.", + type=int, + required=False, + ) parser.set_defaults(func=run) diff --git a/dan/predict/prediction.py b/dan/predict/prediction.py index db3d826f..ac029c84 100644 --- a/dan/predict/prediction.py +++ b/dan/predict/prediction.py @@ -391,6 +391,7 @@ def run( threshold_method, threshold_value, image_extension, + gpu_device, ): """ Predict a single image save the output @@ -411,13 +412,15 @@ def run( :param predict_objects: Whether to extract objects. :param threshold_method: Thresholding method. Should be in ["otsu", "simple"]. :param threshold_value: Thresholding value to use for the "simple" thresholding method. + :param gpu_device: Use a specific GPU if available. """ # Create output directory if necessary if not os.path.exists(output): os.makedirs(output, exist_ok=True) # Load model - device = "cuda" if torch.cuda.is_available() else "cpu" + cuda_device = f":{gpu_device}" if gpu_device is not None else "" + device = f"cuda{cuda_device}" if torch.cuda.is_available() else "cpu" dan_model = DAN(device, temperature) dan_model.load(model, parameters, charset, mode="train") if image: -- GitLab