From 67f28861bcada884c1c549aabd403a0452a12c89 Mon Sep 17 00:00:00 2001 From: M Generali <mgenerali@teklia.com> Date: Wed, 28 Jun 2023 11:20:35 +0000 Subject: [PATCH] Input prediction folder --- dan/predict/__init__.py | 17 +++-- dan/predict/prediction.py | 127 ++++++++++++++++++++++++++++---------- docs/usage/predict.md | 4 +- 3 files changed, 110 insertions(+), 38 deletions(-) diff --git a/dan/predict/__init__.py b/dan/predict/__init__.py index 44537752..8d69e417 100644 --- a/dan/predict/__init__.py +++ b/dan/predict/__init__.py @@ -14,13 +14,16 @@ def add_predict_parser(subcommands) -> None: description=__doc__, help=__doc__, ) - # Required arguments. - parser.add_argument( + image_or_folder_input = parser.add_mutually_exclusive_group(required=True) + image_or_folder_input.add_argument( "--image", - type=pathlib.Path, help="Path to the image to predict.", - required=True, + ) + image_or_folder_input.add_argument( + "--image-dir", + type=pathlib.Path, + help="Path to the folder where the images to predict are stored.", ) parser.add_argument( "--model", @@ -48,6 +51,12 @@ def add_predict_parser(subcommands) -> None: required=True, ) # Optional arguments. + parser.add_argument( + "--image-extension", + type=str, + help="The extension of the images in the folder.", + default=".jpg", + ) parser.add_argument( "--scale", type=float, diff --git a/dan/predict/prediction.py b/dan/predict/prediction.py index 1c2e3954..bf82b19d 100644 --- a/dan/predict/prediction.py +++ b/dan/predict/prediction.py @@ -251,11 +251,10 @@ class DAN: return out -def run( +def process_image( image, - model, - parameters, - charset, + dan_model, + device, output, scale, confidence_score, @@ -265,40 +264,11 @@ def run( attention_map_scale, word_separators, line_separators, - temperature, image_max_width, predict_objects, threshold_method, threshold_value, ): - """ - Predict a single image save the output - :param image: Path to the image to predict. 
- :param model: Path to the model to use for prediction. - :param parameters: Path to the YAML parameters file. - :param charset: Path to the charset. - :param output: Path to the output folder where the results will be saved. - :param scale: Scaling factor to resize the image. - :param confidence_score: Whether to compute confidence score. - :param attention_map: Whether to plot the attention map. - :param attention_map_level: Level of objects to extract. - :param attention_map_scale: Scaling factor for the attention map. - :param word_separators: List of word separators. - :param line_separators: List of line separators. - :param image_max_width: Resize image - :param predict_objects: Whether to extract objects. - :param threshold_method: Thresholding method. Should be in ["otsu", "simple"]. - :param threshold_value: Thresholding value to use for the "simple" thresholding method. - """ - # Create output directory if necessary - if not os.path.exists(output): - os.mkdir(output) - - # Load model - device = "cuda" if torch.cuda.is_available() else "cpu" - dan_model = DAN(device, temperature) - dan_model.load(model, parameters, charset, mode="eval") - # Load image and pre-process it if image_max_width: _, w, _ = read_image(image, scale=1).shape @@ -396,3 +366,94 @@ def run( json_filename = f"{output}/{image.stem}.json" logger.info(f"Saving JSON prediction in {json_filename}") save_json(Path(json_filename), result) + + +def run( + image, + image_dir, + model, + parameters, + charset, + output, + scale, + confidence_score, + confidence_score_levels, + attention_map, + attention_map_level, + attention_map_scale, + word_separators, + line_separators, + temperature, + image_max_width, + predict_objects, + threshold_method, + threshold_value, + image_extension, +): + """ + Predict a single image, or every image found in a folder, and save the output + :param image: Path to the image to predict. + :param image_dir: Path to the folder where the images to predict are stored. 
+ :param model: Path to the model to use for prediction. + :param parameters: Path to the YAML parameters file. + :param charset: Path to the charset. + :param output: Path to the output folder where the results will be saved. + :param scale: Scaling factor to resize the image. + :param confidence_score: Whether to compute confidence score. + :param attention_map: Whether to plot the attention map. + :param attention_map_level: Level of objects to extract. + :param attention_map_scale: Scaling factor for the attention map. + :param word_separators: List of word separators. + :param line_separators: List of line separators. + :param image_max_width: Resize image + :param predict_objects: Whether to extract objects. + :param threshold_method: Thresholding method. Should be in ["otsu", "simple"]. + :param threshold_value: Thresholding value to use for the "simple" thresholding method. + """ + # Create output directory if necessary + if not os.path.exists(output): + os.makedirs(output, exist_ok=True) + + # Load model + device = "cuda" if torch.cuda.is_available() else "cpu" + dan_model = DAN(device, temperature) + dan_model.load(model, parameters, charset, mode="eval") + if image: + process_image( + Path(image), + dan_model, + device, + output, + scale, + confidence_score, + confidence_score_levels, + attention_map, + attention_map_level, + attention_map_scale, + word_separators, + line_separators, + image_max_width, + predict_objects, + threshold_method, + threshold_value, + ) + else: + for image_name in image_dir.rglob(f"*{image_extension}"): + process_image( + image_name, + dan_model, + device, + output, + scale, + confidence_score, + confidence_score_levels, + attention_map, + attention_map_level, + attention_map_scale, + word_separators, + line_separators, + image_max_width, + predict_objects, + threshold_method, + threshold_value, + ) diff --git a/docs/usage/predict.md b/docs/usage/predict.md index ea8ddd93..a15d425c 100644 --- a/docs/usage/predict.md +++ 
b/docs/usage/predict.md @@ -6,7 +6,9 @@ Use the `teklia-dan predict` command to predict a trained DAN model on an image. | Parameter | Description | Type | Default | | --------------------------- | -------------------------------------------------------------------------------------------- | ------- | ------------- | -| `--image` | Path to the image to predict. | `Path` | | +| `--image` | Path to the image to predict. Must not be provided with `--image-dir`. | `Path` | | +| `--image-dir` | Path to the folder where the images to predict are stored. Must not be provided with `--image`. | `Path` | | +| `--image-extension` | The extension of the images in the folder. Ignored if `--image-dir` is not provided. | `str` | .jpg | | `--model` | Path to the model to use for prediction | `Path` | | | `--parameters` | Path to the YAML parameters file. | `Path` | | | `--charset` | Path to the charset file. | `Path` | | -- GitLab