From 67f28861bcada884c1c549aabd403a0452a12c89 Mon Sep 17 00:00:00 2001
From: M Generali <mgenerali@teklia.com>
Date: Wed, 28 Jun 2023 11:20:35 +0000
Subject: [PATCH] Allow prediction on an input folder of images

---
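Note (not part of the commit message): with `--image-dir`, prediction walks the given
folder recursively and runs on every file whose name ends with the configured
extension. A minimal sketch of the matching behaviour, assuming a hypothetical
`images/` folder and the default extension:

    # Sketch only: mirrors the rglob pattern used in run(); the folder name is a placeholder.
    from pathlib import Path

    image_dir = Path("images")      # stands in for --image-dir
    image_extension = ".jpg"        # default value of --image-extension

    # Matching is recursive and suffix-based: "scan.jpg" and "pages/01.jpg" both match.
    for image_path in sorted(image_dir.rglob(f"*{image_extension}")):
        print(image_path)
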
 dan/predict/__init__.py   |  16 ++++-
 dan/predict/prediction.py | 127 ++++++++++++++++++++++++++++----------
 docs/usage/predict.md     |   4 +-
 3 files changed, 110 insertions(+), 37 deletions(-)

diff --git a/dan/predict/__init__.py b/dan/predict/__init__.py
index 44537752..8d69e417 100644
--- a/dan/predict/__init__.py
+++ b/dan/predict/__init__.py
@@ -14,13 +14,17 @@ def add_predict_parser(subcommands) -> None:
         description=__doc__,
         help=__doc__,
     )
-
     # Required arguments.
-    parser.add_argument(
+    image_or_folder_input = parser.add_mutually_exclusive_group(required=True)
+    image_or_folder_input.add_argument(
         "--image",
         type=pathlib.Path,
         help="Path to the image to predict.",
-        required=True,
+    )
+    image_or_folder_input.add_argument(
+        "--image-dir",
+        type=pathlib.Path,
+        help="Path to the folder where the images to predict are stored.",
     )
     parser.add_argument(
         "--model",
@@ -48,6 +51,12 @@ def add_predict_parser(subcommands) -> None:
         required=True,
     )
     # Optional arguments.
+    parser.add_argument(
+        "--image-extension",
+        type=str,
+        help="The extension of the images in the folder.",
+        default=".jpg",
+    )
     parser.add_argument(
         "--scale",
         type=float,
diff --git a/dan/predict/prediction.py b/dan/predict/prediction.py
index 1c2e3954..bf82b19d 100644
--- a/dan/predict/prediction.py
+++ b/dan/predict/prediction.py
@@ -251,11 +251,10 @@ class DAN:
         return out
 
 
-def run(
+def process_image(
     image,
-    model,
-    parameters,
-    charset,
+    dan_model,
+    device,
     output,
     scale,
     confidence_score,
@@ -265,40 +264,11 @@ def run(
     attention_map_scale,
     word_separators,
     line_separators,
-    temperature,
     image_max_width,
     predict_objects,
     threshold_method,
     threshold_value,
 ):
-    """
-    Predict a single image save the output
-    :param image: Path to the image to predict.
-    :param model: Path to the model to use for prediction.
-    :param parameters: Path to the YAML parameters file.
-    :param charset: Path to the charset.
-    :param output: Path to the output folder where the results will be saved.
-    :param scale: Scaling factor to resize the image.
-    :param confidence_score: Whether to compute confidence score.
-    :param attention_map: Whether to plot the attention map.
-    :param attention_map_level: Level of objects to extract.
-    :param attention_map_scale: Scaling factor for the attention map.
-    :param word_separators: List of word separators.
-    :param line_separators: List of line separators.
-    :param image_max_width: Resize image
-    :param predict_objects: Whether to extract objects.
-    :param threshold_method: Thresholding method. Should be in ["otsu", "simple"].
-    :param threshold_value: Thresholding value to use for the "simple" thresholding method.
-    """
-    # Create output directory if necessary
-    if not os.path.exists(output):
-        os.mkdir(output)
-
-    # Load model
-    device = "cuda" if torch.cuda.is_available() else "cpu"
-    dan_model = DAN(device, temperature)
-    dan_model.load(model, parameters, charset, mode="eval")
-
     # Load image and pre-process it
     if image_max_width:
         _, w, _ = read_image(image, scale=1).shape
@@ -396,3 +366,94 @@ def run(
     json_filename = f"{output}/{image.stem}.json"
     logger.info(f"Saving JSON prediction in {json_filename}")
     save_json(Path(json_filename), result)
+
+
+def run(
+    image,
+    image_dir,
+    model,
+    parameters,
+    charset,
+    output,
+    scale,
+    confidence_score,
+    confidence_score_levels,
+    attention_map,
+    attention_map_level,
+    attention_map_scale,
+    word_separators,
+    line_separators,
+    temperature,
+    image_max_width,
+    predict_objects,
+    threshold_method,
+    threshold_value,
+    image_extension,
+):
+    """
+    Predict a single image or all images in a folder, and save the outputs.
+    :param image: Path to the image to predict.
+    :param image_dir: Path to the folder where the images to predict are stored.
+    :param image_extension: The extension of the images in the folder.
+    :param model: Path to the model to use for prediction.
+    :param parameters: Path to the YAML parameters file.
+    :param charset: Path to the charset.
+    :param output: Path to the output folder where the results will be saved.
+    :param scale: Scaling factor to resize the image.
+    :param confidence_score: Whether to compute confidence score.
+    :param attention_map: Whether to plot the attention map.
+    :param attention_map_level: Level of objects to extract.
+    :param attention_map_scale: Scaling factor for the attention map.
+    :param word_separators: List of word separators.
+    :param line_separators: List of line separators.
+    :param image_max_width: Maximum width to which the image is resized before prediction.
+    :param predict_objects: Whether to extract objects.
+    :param threshold_method: Thresholding method. Should be in ["otsu", "simple"].
+    :param threshold_value: Thresholding value to use for the "simple" thresholding method.
+    """
+    # Create output directory if necessary
+    os.makedirs(output, exist_ok=True)
+
+    # Load model
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    dan_model = DAN(device, temperature)
+    dan_model.load(model, parameters, charset, mode="train")
+    if image:
+        process_image(
+            image,
+            dan_model,
+            device,
+            output,
+            scale,
+            confidence_score,
+            confidence_score_levels,
+            attention_map,
+            attention_map_level,
+            attention_map_scale,
+            word_separators,
+            line_separators,
+            image_max_width,
+            predict_objects,
+            threshold_method,
+            threshold_value,
+        )
+    else:
+        for image_name in image_dir.rglob(f"*{image_extension}"):
+            process_image(
+                image_name,
+                dan_model,
+                device,
+                output,
+                scale,
+                confidence_score,
+                confidence_score_levels,
+                attention_map,
+                attention_map_level,
+                attention_map_scale,
+                word_separators,
+                line_separators,
+                image_max_width,
+                predict_objects,
+                threshold_method,
+                threshold_value,
+            )
diff --git a/docs/usage/predict.md b/docs/usage/predict.md
index ea8ddd93..a15d425c 100644
--- a/docs/usage/predict.md
+++ b/docs/usage/predict.md
@@ -6,7 +6,9 @@ Use the `teklia-dan predict` command to predict a trained DAN model on an image.
 
 | Parameter                   | Description                                                                                  | Type    | Default       |
 | --------------------------- | -------------------------------------------------------------------------------------------- | ------- | ------------- |
-| `--image`                   | Path to the image to predict.                                                                | `Path`  |               |
+| `--image`                   | Path to the image to predict. Must not be provided with `--image-dir`.                      | `Path`  |               |
+| `--image-dir`               | Path to the folder where the images to predict are stored. Must not be provided with `--image`. | `Path`  |               |
+| `--image-extension`         | The extension of the images in the folder. Ignored if `--image-dir` is not provided.        | `str`   | .jpg          |
 | `--model`                   | Path to the model to use for prediction                                                      | `Path`  |               |
 | `--parameters`              | Path to the YAML parameters file.                                                            | `Path`  |               |
 | `--charset`                 | Path to the charset file.                                                                    | `Path`  |               |
-- 
GitLab