Skip to content
Snippets Groups Projects
Commit 67f28861 authored by Marie Generali's avatar Marie Generali :worried: Committed by Yoann Schneider
Browse files

Input prediction folder

parent 3a1fea73
No related branches found
No related tags found
1 merge request!176Input prediction folder
...@@ -14,13 +14,16 @@ def add_predict_parser(subcommands) -> None: ...@@ -14,13 +14,16 @@ def add_predict_parser(subcommands) -> None:
description=__doc__, description=__doc__,
help=__doc__, help=__doc__,
) )
# Required arguments. # Required arguments.
parser.add_argument( image_or_folder_input = parser.add_mutually_exclusive_group(required=True)
image_or_folder_input.add_argument(
"--image", "--image",
type=pathlib.Path,
help="Path to the image to predict.", help="Path to the image to predict.",
required=True, )
image_or_folder_input.add_argument(
"--image-dir",
type=pathlib.Path,
help="Path to the folder where the images to predict are stored.",
) )
parser.add_argument( parser.add_argument(
"--model", "--model",
...@@ -48,6 +51,12 @@ def add_predict_parser(subcommands) -> None: ...@@ -48,6 +51,12 @@ def add_predict_parser(subcommands) -> None:
required=True, required=True,
) )
# Optional arguments. # Optional arguments.
parser.add_argument(
"--image-extension",
type=str,
help="The extension of the images in the folder.",
default=".jpg",
)
parser.add_argument( parser.add_argument(
"--scale", "--scale",
type=float, type=float,
......
...@@ -251,11 +251,10 @@ class DAN: ...@@ -251,11 +251,10 @@ class DAN:
return out return out
def run( def process_image(
image, image,
model, dan_model,
parameters, device,
charset,
output, output,
scale, scale,
confidence_score, confidence_score,
...@@ -265,40 +264,11 @@ def run( ...@@ -265,40 +264,11 @@ def run(
attention_map_scale, attention_map_scale,
word_separators, word_separators,
line_separators, line_separators,
temperature,
image_max_width, image_max_width,
predict_objects, predict_objects,
threshold_method, threshold_method,
threshold_value, threshold_value,
): ):
"""
Predict a single image save the output
:param image: Path to the image to predict.
:param model: Path to the model to use for prediction.
:param parameters: Path to the YAML parameters file.
:param charset: Path to the charset.
:param output: Path to the output folder where the results will be saved.
:param scale: Scaling factor to resize the image.
:param confidence_score: Whether to compute confidence score.
:param attention_map: Whether to plot the attention map.
:param attention_map_level: Level of objects to extract.
:param attention_map_scale: Scaling factor for the attention map.
:param word_separators: List of word separators.
:param line_separators: List of line separators.
:param image_max_width: Resize image
:param predict_objects: Whether to extract objects.
:param threshold_method: Thresholding method. Should be in ["otsu", "simple"].
:param threshold_value: Thresholding value to use for the "simple" thresholding method.
"""
# Create output directory if necessary
if not os.path.exists(output):
os.mkdir(output)
# Load model
device = "cuda" if torch.cuda.is_available() else "cpu"
dan_model = DAN(device, temperature)
dan_model.load(model, parameters, charset, mode="eval")
# Load image and pre-process it # Load image and pre-process it
if image_max_width: if image_max_width:
_, w, _ = read_image(image, scale=1).shape _, w, _ = read_image(image, scale=1).shape
...@@ -396,3 +366,94 @@ def run( ...@@ -396,3 +366,94 @@ def run(
json_filename = f"{output}/{image.stem}.json" json_filename = f"{output}/{image.stem}.json"
logger.info(f"Saving JSON prediction in {json_filename}") logger.info(f"Saving JSON prediction in {json_filename}")
save_json(Path(json_filename), result) save_json(Path(json_filename), result)
def run(
    image,
    image_dir,
    model,
    parameters,
    charset,
    output,
    scale,
    confidence_score,
    confidence_score_levels,
    attention_map,
    attention_map_level,
    attention_map_scale,
    word_separators,
    line_separators,
    temperature,
    image_max_width,
    predict_objects,
    threshold_method,
    threshold_value,
    image_extension,
):
    """
    Predict a single image, or every matching image of a folder, and save the outputs.

    The model is loaded once and reused for all the images.

    :param image: Path to the image to predict. When set, `image_dir` is ignored.
    :param image_dir: Path to the folder where the images to predict are stored. Searched recursively.
    :param model: Path to the model to use for prediction.
    :param parameters: Path to the YAML parameters file.
    :param charset: Path to the charset.
    :param output: Path to the output folder where the results will be saved.
    :param scale: Scaling factor to resize the image.
    :param confidence_score: Whether to compute confidence score.
    :param confidence_score_levels: Confidence score levels, forwarded unchanged to `process_image`.
    :param attention_map: Whether to plot the attention map.
    :param attention_map_level: Level of objects to extract.
    :param attention_map_scale: Scaling factor for the attention map.
    :param word_separators: List of word separators.
    :param line_separators: List of line separators.
    :param temperature: Temperature given to the `DAN` constructor.
    :param image_max_width: Resize image
    :param predict_objects: Whether to extract objects.
    :param threshold_method: Thresholding method. Should be in ["otsu", "simple"].
    :param threshold_value: Thresholding value to use for the "simple" thresholding method.
    :param image_extension: Extension of the images to look for in `image_dir`. Ignored when `image` is set.
    """
    # Create output directory if necessary (`exist_ok` already covers the
    # "directory is already there" case, so no prior existence check is needed)
    os.makedirs(output, exist_ok=True)

    # Load model once, whatever the number of images to process
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dan_model = DAN(device, temperature)
    # This function only runs inference: load the model in evaluation mode,
    # not in training mode
    dan_model.load(model, parameters, charset, mode="eval")

    # A single image takes precedence; otherwise walk the folder recursively
    if image:
        image_paths = [image]
    else:
        image_paths = image_dir.rglob(f"*{image_extension}")

    processed = 0
    for image_path in image_paths:
        process_image(
            image_path,
            dan_model,
            device,
            output,
            scale,
            confidence_score,
            confidence_score_levels,
            attention_map,
            attention_map_level,
            attention_map_scale,
            word_separators,
            line_separators,
            image_max_width,
            predict_objects,
            threshold_method,
            threshold_value,
        )
        processed += 1

    # Make an empty run visible instead of exiting silently
    if not processed:
        logger.warning(
            f"No image with extension {image_extension} found in {image_dir}"
        )
...@@ -6,7 +6,9 @@ Use the `teklia-dan predict` command to predict a trained DAN model on an image. ...@@ -6,7 +6,9 @@ Use the `teklia-dan predict` command to predict a trained DAN model on an image.
| Parameter | Description | Type | Default | | Parameter | Description | Type | Default |
| --------------------------- | -------------------------------------------------------------------------------------------- | ------- | ------------- | | --------------------------- | -------------------------------------------------------------------------------------------- | ------- | ------------- |
| `--image` | Path to the image to predict. | `Path` | | | `--image` | Path to the image to predict. Must not be provided with `--image-dir`. | `Path` | |
| `--image-dir` | Path to the folder where the images to predict are stored. Must not be provided with `--image`. | `Path` | |
| `--image-extension` | The extension of the images in the folder. Ignored if `--image-dir` is not provided. | `str` | .jpg |
| `--model` | Path to the model to use for prediction | `Path` | | | `--model` | Path to the model to use for prediction | `Path` | |
| `--parameters` | Path to the YAML parameters file. | `Path` | | | `--parameters` | Path to the YAML parameters file. | `Path` | |
| `--charset` | Path to the charset file. | `Path` | | | `--charset` | Path to the charset file. | `Path` | |
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment