From ac7a25aff35f3da12a8012ff36c69cbda5225a05 Mon Sep 17 00:00:00 2001
From: Manon Blanco <blanco@teklia.com>
Date: Thu, 2 Nov 2023 11:34:31 +0000
Subject: [PATCH] Only predict on folders, no longer support single image

---
 dan/ocr/predict/__init__.py  |  8 +-------
 dan/ocr/predict/inference.py |  4 +---
 docs/usage/predict/index.md  | 35 +++++++++++++++++------------------
 tests/test_prediction.py     | 15 ++++++++++-----
 4 files changed, 29 insertions(+), 33 deletions(-)

diff --git a/dan/ocr/predict/__init__.py b/dan/ocr/predict/__init__.py
index 000d7c99..66e30710 100644
--- a/dan/ocr/predict/__init__.py
+++ b/dan/ocr/predict/__init__.py
@@ -17,13 +17,8 @@ def add_predict_parser(subcommands) -> None:
         help=__doc__,
     )
     # Required arguments.
-    image_or_folder_input = parser.add_mutually_exclusive_group(required=True)
-    image_or_folder_input.add_argument(
-        "--image",
-        type=pathlib.Path,
-        help="Path to the image to predict.",
-    )
-    image_or_folder_input.add_argument(
+    parser.add_argument(
         "--image-dir",
+        required=True,
         type=pathlib.Path,
         help="Path to the folder where the images to predict are stored.",
diff --git a/dan/ocr/predict/inference.py b/dan/ocr/predict/inference.py
index 7f463a33..d65931f1 100644
--- a/dan/ocr/predict/inference.py
+++ b/dan/ocr/predict/inference.py
@@ -412,7 +412,6 @@ def process_batch(
 
 
 def run(
-    image: Optional[Path],
     image_dir: Optional[Path],
     model: Path,
     output: Path,
@@ -435,7 +434,6 @@ def run(
 ) -> None:
     """
-    Predict a single image save the output
-    :param image: Path to the image to predict.
+    Predict all the images in a folder and save the outputs
     :param image_dir: Path to the folder where the images to predict are stored.
     :param model: Path to the directory containing the model, the YAML parameters file and the charset file to use for prediction.
     :param output: Path to the output folder where the results will be saved.
@@ -466,7 +464,7 @@ def run(
     # Do not use LM with invalid LM weight
     use_language_model = dan_model.lm_decoder is not None
 
-    images = image_dir.rglob(f"*{image_extension}") if not image else [image]
+    images = image_dir.rglob(f"*{image_extension}")
     for image_batch in list_to_batches(images, n=batch_size):
         process_batch(
             image_batch,
diff --git a/docs/usage/predict/index.md b/docs/usage/predict/index.md
index bfa15869..b04defa4 100644
--- a/docs/usage/predict/index.md
+++ b/docs/usage/predict/index.md
@@ -6,7 +6,6 @@ Use the `teklia-dan predict` command to apply a trained DAN model on an image.
 
 | Parameter                   | Description                                                                                                                           | Type           | Default       |
 | --------------------------- | ------------------------------------------------------------------------------------------------------------------------------------- | -------------- | ------------- |
-| `--image`                   | Path to the image to predict. Must not be provided with `--image-dir`.                                                                | `pathlib.Path` |               |
-| `--image-dir`               | Path to the folder where the images to predict are stored. Must not be provided with `--image`.                                       | `pathlib.Path` |               |
-| `--image-extension`         | The extension of the images in the folder. Ignored if `--image-dir` is not provided.                                                  | `str`          | .jpg          |
+| `--image-dir`               | Path to the folder where the images to predict are stored.                                                                            | `pathlib.Path` |               |
+| `--image-extension`         | The extension of the images in the folder.                                                                                            | `str`          | .jpg          |
 | `--model`                   | Path to the directory containing the model, the YAML parameters file and the charset file to use for prediction.                      | `pathlib.Path` |               |
@@ -41,13 +40,13 @@ To run a prediction with confidence scores, run this command:
 
 ```shell
 teklia-dan predict \
-    --image example.jpg \
+    --image-dir images/ \
     --model models \
     --output predict/ \
     --confidence-score
 ```
 
-It will create the following JSON file named `predict/example.json`
+It will create the following JSON file named after the image in the `predict` folder:
 
 ```json
 {
@@ -64,14 +63,14 @@ To run a prediction with confidence scores and plot line-level attention maps, r
 
 ```shell
 teklia-dan predict \
-    --image example.jpg \
+    --image-dir images/ \
     --model models \
     --output predict/ \
     --confidence-score \
     --attention-map
 ```
 
-It will create the following JSON file named `predict/example.json` and a GIF showing a line-level attention map `predict/example_line.gif`
+It will create the following JSON file named after the image and a GIF showing a line-level attention map in the `predict` folder:
 
 ```json
 {
@@ -91,7 +90,7 @@ To run a prediction with confidence scores and plot word-level attention maps, r
 
 ```shell
 teklia-dan predict \
-    --image example.jpg \
+    --image-dir images/ \
     --model models \
     --output predict/ \
     --confidence-score \
@@ -100,7 +99,7 @@ teklia-dan predict \
     --attention-map-scale 0.5
 ```
 
-It will create the following JSON file named `predict/example.json` and a GIF showing a word-level attention map `predict/example_word.gif`.
+It will create the following JSON file named after the image and a GIF showing a word-level attention map in the `predict` folder:
 
 ```json
 {
@@ -120,14 +119,14 @@ To run a prediction, plot line-level attention maps, and extract polygons, run t
 
 ```shell
 teklia-dan predict \
-    --image example.jpg \
+    --image-dir images/ \
     --model models \
     --output predict/ \
     --attention-map \
     --predict-objects
 ```
 
-It will create the following JSON file named `predict/example.json` and a GIF showing a line-level attention map with extracted polygons `predict/example_line.gif`
+It will create the following JSON file named after the image and a GIF showing a line-level attention map in the `predict` folder:
 
 ```json
 {
@@ -192,13 +191,13 @@ Then, run this command:
 
 ```shell
 teklia-dan predict \
-    --image dan_humu_page/6e830f23-e70d-4399-8b94-f36ed3198575.jpg \
+    --image-dir images/ \
     --model models \
     --use-language-model \
-    --output dan_humu_page/predict_char_lm/
+    --output predict_char_lm/
 ```
 
-It will create the following JSON file named `dan_humu_page/predict_char_lm/6e830f23-e70d-4399-8b94-f36ed3198575.json`
+It will create the following JSON file named after the image in the `predict_char_lm` folder:
 
 ```json
 {
@@ -228,13 +227,13 @@ Then, run this command:
 
 ```shell
 teklia-dan predict \
-    --image dan_humu_page/6e830f23-e70d-4399-8b94-f36ed3198575.jpg \
+    --image-dir images/ \
     --model models \
     --use-language-model \
-    --output dan_humu_page/predict_subword_lm/
+    --output predict_subword_lm/
 ```
 
-It will create the following JSON file named `dan_humu_page/predict_subword_lm/6e830f23-e70d-4399-8b94-f36ed3198575.json`
+It will create the following JSON file named after the image in the `predict_subword_lm` folder:
 
 ```json
 {
@@ -264,13 +263,13 @@ Then, run this command:
 
 ```shell
 teklia-dan predict \
-    --image dan_humu_page/6e830f23-e70d-4399-8b94-f36ed3198575.jpg \
+    --image-dir images/ \
     --model models \
     --use-language-model \
-    --output dan_humu_page/predict_word_lm/
+    --output predict_word_lm/
 ```
 
-It will create the following JSON file named `dan_humu_page/predict_word_lm/6e830f23-e70d-4399-8b94-f36ed3198575.json`
+It will create the following JSON file named after the image in the `predict_word_lm` folder:
 
 ```json
 {
diff --git a/tests/test_prediction.py b/tests/test_prediction.py
index ffa7bf70..1f9fda21 100644
--- a/tests/test_prediction.py
+++ b/tests/test_prediction.py
@@ -293,9 +293,16 @@ def test_run_prediction(
     expected_prediction,
     tmp_path,
 ):
+    # Make tmpdir and copy needed image inside
+    image_dir = tmp_path / "images"
+    image_dir.mkdir()
+    shutil.copyfile(
+        (PREDICTION_DATA_PATH / "images" / image_name).with_suffix(".png"),
+        (image_dir / image_name).with_suffix(".png"),
+    )
+
     run_prediction(
-        image=(PREDICTION_DATA_PATH / "images" / image_name).with_suffix(".png"),
-        image_dir=None,
+        image_dir=image_dir,
         model=PREDICTION_DATA_PATH,
         output=tmp_path,
         confidence_score=True if confidence_score else False,
@@ -308,7 +315,7 @@ def test_run_prediction(
         temperature=temperature,
         predict_objects=False,
         max_object_height=None,
-        image_extension=None,
+        image_extension=".png",
         gpu_device=None,
         batch_size=1,
         tokens=parse_tokens(PREDICTION_DATA_PATH / "tokens.yml"),
@@ -488,7 +495,6 @@ def test_run_prediction_batch(
         )
 
     run_prediction(
-        image=None,
         image_dir=image_dir,
         model=PREDICTION_DATA_PATH,
         output=tmp_path,
@@ -648,7 +654,6 @@ def test_run_prediction_language_model(
     yaml.dump(params, (model_path / "parameters.yml").open("w"))
 
     run_prediction(
-        image=None,
         image_dir=image_dir,
         model=model_path,
         output=tmp_path,
-- 
GitLab