Skip to content
Snippets Groups Projects
Commit df229f93 authored by Manon Blanco's avatar Manon Blanco Committed by Solene Tarride
Browse files

Enable to use GPU for teklia_dan_predict

parent be01536c
No related branches found
No related tags found
1 merge request!192Enable to use GPU for teklia_dan_predict
...@@ -148,4 +148,10 @@ def add_predict_parser(subcommands) -> None: ...@@ -148,4 +148,10 @@ def add_predict_parser(subcommands) -> None:
type=int, type=int,
default=0, default=0,
) )
parser.add_argument(
"--gpu-device",
help="Use a specific GPU if available.",
type=int,
required=False,
)
parser.set_defaults(func=run) parser.set_defaults(func=run)
...@@ -391,6 +391,7 @@ def run( ...@@ -391,6 +391,7 @@ def run(
threshold_method, threshold_method,
threshold_value, threshold_value,
image_extension, image_extension,
gpu_device,
): ):
""" """
Predict a single image save the output Predict a single image save the output
...@@ -411,13 +412,15 @@ def run( ...@@ -411,13 +412,15 @@ def run(
:param predict_objects: Whether to extract objects. :param predict_objects: Whether to extract objects.
:param threshold_method: Thresholding method. Should be in ["otsu", "simple"]. :param threshold_method: Thresholding method. Should be in ["otsu", "simple"].
:param threshold_value: Thresholding value to use for the "simple" thresholding method. :param threshold_value: Thresholding value to use for the "simple" thresholding method.
:param gpu_device: Use a specific GPU if available.
""" """
# Create output directory if necessary # Create output directory if necessary
if not os.path.exists(output): if not os.path.exists(output):
os.makedirs(output, exist_ok=True) os.makedirs(output, exist_ok=True)
# Load model # Load model
device = "cuda" if torch.cuda.is_available() else "cpu" cuda_device = f":{gpu_device}" if gpu_device is not None else ""
device = f"cuda{cuda_device}" if torch.cuda.is_available() else "cpu"
dan_model = DAN(device, temperature) dan_model = DAN(device, temperature)
dan_model.load(model, parameters, charset, mode="train") dan_model.load(model, parameters, charset, mode="train")
if image: if image:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment