Commit 67f28861 authored by Marie Generali, committed by Yoann Schneider

Input prediction folder

parent 3a1fea73
1 merge request: !176 Input prediction folder
@@ -14,13 +14,16 @@ def add_predict_parser(subcommands) -> None:
        description=__doc__,
        help=__doc__,
    )
    # Required arguments.
    parser.add_argument(
    image_or_folder_input = parser.add_mutually_exclusive_group(required=True)
    image_or_folder_input.add_argument(
        "--image",
        type=pathlib.Path,
        help="Path to the image to predict.",
        required=True,
    )
    image_or_folder_input.add_argument(
        "--image-dir",
        type=pathlib.Path,
        help="Path to the folder where the images to predict are stored.",
    )
    parser.add_argument(
        "--model",
@@ -48,6 +51,12 @@ def add_predict_parser(subcommands) -> None:
        required=True,
    )
    # Optional arguments.
    parser.add_argument(
        "--image-extension",
        type=str,
        help="The extension of the images in the folder.",
        default=".jpg",
    )
    parser.add_argument(
        "--scale",
        type=float,
......
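The two input flags are registered on an argparse mutually exclusive group created with required=True, so the parser itself guarantees that exactly one of --image or --image-dir is supplied. A minimal standalone sketch of that behaviour (the prog name and the example argument lists below are illustrative, not part of the project):

import argparse
import pathlib

# Hypothetical stand-alone parser mirroring the two new input flags.
parser = argparse.ArgumentParser(prog="predict-sketch")
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument("--image", type=pathlib.Path, help="Path to a single image.")
group.add_argument("--image-dir", type=pathlib.Path, help="Folder of images.")

parser.parse_args(["--image", "page.jpg"])      # OK: a single image
parser.parse_args(["--image-dir", "pages/"])    # OK: a folder of images
# parser.parse_args([])                         # error: one of the two flags is required
# parser.parse_args(["--image", "a.jpg", "--image-dir", "b/"])  # error: mutually exclusive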
@@ -251,11 +251,10 @@ class DAN:
        return out
def run(
def process_image(
    image,
    model,
    parameters,
    charset,
    dan_model,
    device,
    output,
    scale,
    confidence_score,
@@ -265,40 +264,11 @@ def run(
    attention_map_scale,
    word_separators,
    line_separators,
    temperature,
    image_max_width,
    predict_objects,
    threshold_method,
    threshold_value,
):
    """
    Predict a single image and save the output.
    :param image: Path to the image to predict.
    :param model: Path to the model to use for prediction.
    :param parameters: Path to the YAML parameters file.
    :param charset: Path to the charset.
    :param output: Path to the output folder where the results will be saved.
    :param scale: Scaling factor to resize the image.
    :param confidence_score: Whether to compute confidence score.
    :param attention_map: Whether to plot the attention map.
    :param attention_map_level: Level of objects to extract.
    :param attention_map_scale: Scaling factor for the attention map.
    :param word_separators: List of word separators.
    :param line_separators: List of line separators.
    :param image_max_width: Maximum width to which the image is resized before prediction.
    :param predict_objects: Whether to extract objects.
    :param threshold_method: Thresholding method. Should be in ["otsu", "simple"].
    :param threshold_value: Thresholding value to use for the "simple" thresholding method.
    """
    # Create output directory if necessary
    if not os.path.exists(output):
        os.mkdir(output)
    # Load model
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dan_model = DAN(device, temperature)
    dan_model.load(model, parameters, charset, mode="eval")
    # Load image and pre-process it
    if image_max_width:
        _, w, _ = read_image(image, scale=1).shape
@@ -396,3 +366,94 @@ def run(
    json_filename = f"{output}/{image.stem}.json"
    logger.info(f"Saving JSON prediction in {json_filename}")
    save_json(Path(json_filename), result)
def run(
    image,
    image_dir,
    model,
    parameters,
    charset,
    output,
    scale,
    confidence_score,
    confidence_score_levels,
    attention_map,
    attention_map_level,
    attention_map_scale,
    word_separators,
    line_separators,
    temperature,
    image_max_width,
    predict_objects,
    threshold_method,
    threshold_value,
    image_extension,
):
"""
Predict a single image save the output
:param image: Path to the image to predict.
:param image_dir: Path to the folder where the images to predict are stored.
:param model: Path to the model to use for prediction.
:param parameters: Path to the YAML parameters file.
:param charset: Path to the charset.
:param output: Path to the output folder where the results will be saved.
:param scale: Scaling factor to resize the image.
:param confidence_score: Whether to compute confidence score.
:param attention_map: Whether to plot the attention map.
:param attention_map_level: Level of objects to extract.
:param attention_map_scale: Scaling factor for the attention map.
:param word_separators: List of word separators.
:param line_separators: List of line separators.
:param image_max_width: Resize image
:param predict_objects: Whether to extract objects.
:param threshold_method: Thresholding method. Should be in ["otsu", "simple"].
:param threshold_value: Thresholding value to use for the "simple" thresholding method.
"""
    # Create output directory if necessary
    if not os.path.exists(output):
        os.makedirs(output, exist_ok=True)
    # Load model
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dan_model = DAN(device, temperature)
    dan_model.load(model, parameters, charset, mode="eval")
    if image:
        process_image(
            image,
            dan_model,
            device,
            output,
            scale,
            confidence_score,
            confidence_score_levels,
            attention_map,
            attention_map_level,
            attention_map_scale,
            word_separators,
            line_separators,
            image_max_width,
            predict_objects,
            threshold_method,
            threshold_value,
        )
    else:
        for image_name in image_dir.rglob(f"*{image_extension}"):
            process_image(
                image_name,
                dan_model,
                device,
                output,
                scale,
                confidence_score,
                confidence_score_levels,
                attention_map,
                attention_map_level,
                attention_map_scale,
                word_separators,
                line_separators,
                image_max_width,
                predict_objects,
                threshold_method,
                threshold_value,
            )
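In the folder branch, image_dir.rglob(f"*{image_extension}") walks the directory recursively, so images in sub-folders are also predicted, and only file names ending with the configured extension (default .jpg) are matched; the pattern is matched literally, so on case-sensitive filesystems .JPG files would need a different --image-extension value. A small self-contained sketch of that collection step, with a hypothetical helper name and directory layout:

from pathlib import Path

def collect_images(image_dir: Path, image_extension: str = ".jpg"):
    """Hypothetical helper mirroring the folder branch of run():
    recursively list every file under image_dir whose name ends with image_extension."""
    # sorted() only adds a stable order for the example; run() iterates in rglob order.
    return sorted(image_dir.rglob(f"*{image_extension}"))

# With pages/scan_01.jpg, pages/batch_2/scan_02.jpg and pages/notes.txt on disk,
# this prints the two .jpg paths, including the nested one, and skips notes.txt.
for image_path in collect_images(Path("pages"), ".jpg"):
    print(image_path)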
@@ -6,7 +6,9 @@ Use the `teklia-dan predict` command to predict a trained DAN model on an image.
| Parameter | Description | Type | Default |
| --------------------------- | -------------------------------------------------------------------------------------------- | ------- | ------------- |
| `--image` | Path to the image to predict. | `Path` | |
| `--image` | Path to the image to predict. Must not be provided with `--image-dir`. | `Path` | |
| `--image-dir` | Path to the folder where the images to predict are stored. Must not be provided with `--image`. | `Path` | |
| `--image-extension` | The extension of the images in the folder. Ignored if `--image-dir` is not provided. | `str` | .jpg |
| `--model` | Path to the model to use for prediction | `Path` | |
| `--parameters` | Path to the YAML parameters file. | `Path` | |
| `--charset` | Path to the charset file. | `Path` | |
......
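Putting the documented folder options together, the sketch below shows how they map onto a direct call of the run() function added in this commit, with --image left unset so the folder branch is taken. Every concrete value is a placeholder chosen for illustration (file names, separators, thresholds), the import path is an assumption, and the supported entry point remains the teklia-dan predict command described in the table.

from pathlib import Path

# Assumption: run() from the prediction module shown in the diff above is importable like this.
from dan.predict import run

run(
    image=None,                         # no single image: takes the --image-dir branch
    image_dir=Path("pages/"),           # placeholder folder of scanned pages
    model=Path("model.pt"),             # placeholder paths to the trained model files
    parameters=Path("parameters.yml"),
    charset=Path("charset.pkl"),
    output=Path("predictions/"),
    scale=1.0,
    confidence_score=False,
    confidence_score_levels=[],
    attention_map=False,
    attention_map_level="line",         # unused here since attention_map=False
    attention_map_scale=0.5,
    word_separators=[" "],              # placeholder separators
    line_separators=["\n"],
    temperature=1.0,
    image_max_width=None,
    predict_objects=False,
    threshold_method="otsu",
    threshold_value=0,                  # only used with the "simple" method
    image_extension=".jpg",             # matches the documented default
)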