Skip to content
Snippets Groups Projects
Commit df229f93 authored by Manon Blanco's avatar Manon Blanco Committed by Solene Tarride
Browse files

Enable to use GPU for teklia_dan_predict

parent be01536c
No related branches found
No related tags found
1 merge request!192Enable to use GPU for teklia_dan_predict
......@@ -148,4 +148,10 @@ def add_predict_parser(subcommands) -> None:
type=int,
default=0,
)
parser.add_argument(
"--gpu-device",
help="Use a specific GPU if available.",
type=int,
required=False,
)
parser.set_defaults(func=run)
......@@ -391,6 +391,7 @@ def run(
threshold_method,
threshold_value,
image_extension,
gpu_device,
):
"""
Predict a single image save the output
......@@ -411,13 +412,15 @@ def run(
:param predict_objects: Whether to extract objects.
:param threshold_method: Thresholding method. Should be in ["otsu", "simple"].
:param threshold_value: Thresholding value to use for the "simple" thresholding method.
:param gpu_device: Use a specific GPU if available.
"""
# Create output directory if necessary
if not os.path.exists(output):
os.makedirs(output, exist_ok=True)
# Load model
device = "cuda" if torch.cuda.is_available() else "cpu"
cuda_device = f":{gpu_device}" if gpu_device is not None else ""
device = f"cuda{cuda_device}" if torch.cuda.is_available() else "cpu"
dan_model = DAN(device, temperature)
dan_model.load(model, parameters, charset, mode="train")
if image:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment