From e50d239505ee62e198d5c05e8b8ea3e4501ddee3 Mon Sep 17 00:00:00 2001
From: manonBlanco <blanco@teklia.com>
Date: Mon, 30 Oct 2023 14:48:56 +0100
Subject: [PATCH] Only predict on folders, no longer support single image

---
 dan/ocr/predict/__init__.py  |  8 +-----
 dan/ocr/predict/inference.py |  4 +--
 docs/usage/predict/index.md  | 53 ++++++++++++++++++------------------
 tests/test_prediction.py     | 15 ++++++----
 4 files changed, 38 insertions(+), 42 deletions(-)

diff --git a/dan/ocr/predict/__init__.py b/dan/ocr/predict/__init__.py
index fd535905..ac79e954 100644
--- a/dan/ocr/predict/__init__.py
+++ b/dan/ocr/predict/__init__.py
@@ -17,13 +17,7 @@ def add_predict_parser(subcommands) -> None:
         help=__doc__,
     )
     # Required arguments.
-    image_or_folder_input = parser.add_mutually_exclusive_group(required=True)
-    image_or_folder_input.add_argument(
-        "--image",
-        type=pathlib.Path,
-        help="Path to the image to predict.",
-    )
-    image_or_folder_input.add_argument(
+    parser.add_argument(
         "--image-dir",
         type=pathlib.Path,
         help="Path to the folder where the images to predict are stored.",
diff --git a/dan/ocr/predict/inference.py b/dan/ocr/predict/inference.py
index e012a5bf..23eb92d7 100644
--- a/dan/ocr/predict/inference.py
+++ b/dan/ocr/predict/inference.py
@@ -407,7 +407,6 @@ def process_batch(
 
 
 def run(
-    image: Optional[Path],
     image_dir: Optional[Path],
     model: Path,
     parameters: Path,
@@ -432,7 +431,6 @@ def run(
 ) -> None:
     """
     Predict a single image save the output
-    :param image: Path to the image to predict.
     :param image_dir: Path to the folder where the images to predict are stored.
     :param model: Path to the model to use for prediction.
     :param parameters: Path to the YAML parameters file.
@@ -467,7 +465,7 @@ def run(
     # Do not use LM with invalid LM weight
     use_language_model = dan_model.lm_decoder is not None
 
-    images = image_dir.rglob(f"*{image_extension}") if not image else [image]
+    images = image_dir.rglob(f"*{image_extension}")
     for image_batch in list_to_batches(images, n=batch_size):
         process_batch(
             image_batch,
diff --git a/docs/usage/predict/index.md b/docs/usage/predict/index.md
index 51f99d4a..dee5adaa 100644
--- a/docs/usage/predict/index.md
+++ b/docs/usage/predict/index.md
@@ -6,7 +6,6 @@ Use the `teklia-dan predict` command to apply a trained DAN model on an image.
 
 | Parameter                   | Description                                                                                                                           | Type           | Default       |
 | --------------------------- | ------------------------------------------------------------------------------------------------------------------------------------- | -------------- | ------------- |
-| `--image`                   | Path to the image to predict. Must not be provided with `--image-dir`.                                                                | `pathlib.Path` |               |
 | `--image-dir`               | Path to the folder where the images to predict are stored. Must not be provided with `--image`.                                       | `pathlib.Path` |               |
 | `--image-extension`         | The extension of the images in the folder. Ignored if `--image-dir` is not provided.                                                  | `str`          | .jpg          |
 | `--model`                   | Path to the model to use for prediction                                                                                               | `pathlib.Path` |               |
@@ -37,7 +36,7 @@ To run a prediction with confidence scores, run this command:
 
 ```shell
 teklia-dan predict \
-    --image example.jpg \
+    --image-dir images/ \
     --model model.pt \
     --parameters inference_parameters.yml \
     --charset charset.pkl \
@@ -45,7 +44,7 @@ teklia-dan predict \
     --confidence-score
 ```
 
-It will create the following JSON file named `predict/example.json`
+It will create the following JSON file named from the image filename in the `predict` folder:
 
 ```json
 {
@@ -62,7 +61,7 @@ To run a prediction with confidence scores and plot line-level attention maps, r
 
 ```shell
 teklia-dan predict \
-    --image example.jpg \
+    --image-dir images/ \
     --model model.pt \
     --parameters inference_parameters.yml \
     --charset charset.pkl \
@@ -71,7 +70,7 @@ teklia-dan predict \
     --attention-map
 ```
 
-It will create the following JSON file named `predict/example.json` and a GIF showing a line-level attention map `predict/example_line.gif`
+It will create the following JSON file named from the image filename and a GIF showing a line-level attention map in the `predict` folder:
 
 ```json
 {
@@ -91,7 +90,7 @@ To run a prediction with confidence scores and plot word-level attention maps, r
 
 ```shell
 teklia-dan predict \
-    --image example.jpg \
+    --image-dir images/ \
     --model model.pt \
     --parameters inference_parameters.yml \
     --charset charset.pkl \
@@ -102,7 +101,7 @@ teklia-dan predict \
     --attention-map-scale 0.5
 ```
 
-It will create the following JSON file named `predict/example.json` and a GIF showing a word-level attention map `predict/example_word.gif`.
+It will create the following JSON file named from the image filename and a GIF showing a word-level attention map in the `predict` folder:
 
 ```json
 {
@@ -122,7 +121,7 @@ To run a prediction, plot line-level attention maps, and extract polygons, run t
 
 ```shell
 teklia-dan predict \
-    --image example.jpg \
+    --image-dir images/ \
     --model model.pt \
     --parameters inference_parameters.yml \
     --charset charset.pkl \
@@ -131,7 +130,7 @@ teklia-dan predict \
     --predict-objects
 ```
 
-It will create the following JSON file named `predict/example.json` and a GIF showing a line-level attention map with extracted polygons `predict/example_line.gif`
+It will create the following JSON file named from the image filename and a GIF showing a line-level attention map in the `predict` folder:
 
 ```json
 {
@@ -196,15 +195,15 @@ Then, run this command:
 
 ```shell
 teklia-dan predict \
-    --image dan_humu_page/6e830f23-e70d-4399-8b94-f36ed3198575.jpg \
-    --model dan_humu_page/model.pt \
-    --parameters dan_humu_page/inference_parameters_char_lm.yml \
-    --charset dan_humu_page/charset.pkl \
+    --image-dir images/ \
+    --model model.pt \
+    --parameters inference_parameters_char_lm.yml \
+    --charset charset.pkl \
     --use-language-model \
-    --output dan_humu_page/predict_char_lm/
+    --output predict_char_lm/
 ```
 
-It will create the following JSON file named `dan_humu_page/predict_char_lm/6e830f23-e70d-4399-8b94-f36ed3198575.json`
+It will create the following JSON file named from the image filename in the `predict_char_lm` folder:
 
 ```json
 {
@@ -234,15 +233,15 @@ Then, run this command:
 
 ```shell
 teklia-dan predict \
-    --image dan_humu_page/6e830f23-e70d-4399-8b94-f36ed3198575.jpg \
-    --model dan_humu_page/model.pt \
-    --parameters dan_humu_page/inference_parameters_subword_lm.yml \
-    --charset dan_humu_page/charset.pkl \
+    --image-dir images/ \
+    --model model.pt \
+    --parameters inference_parameters_subword_lm.yml \
+    --charset charset.pkl \
     --use-language-model \
-    --output dan_humu_page/predict_subword_lm/
+    --output predict_subword_lm
 ```
 
-It will create the following JSON file named `dan_humu_page/predict_subword_lm/6e830f23-e70d-4399-8b94-f36ed3198575.json`
+It will create the following JSON file named from the image filename in the `predict_subword_lm` folder:
 
 ```json
 {
@@ -272,15 +271,15 @@ Then, run this command:
 
 ```shell
 teklia-dan predict \
-    --image dan_humu_page/6e830f23-e70d-4399-8b94-f36ed3198575.jpg \
-    --model dan_humu_page/model.pt \
-    --parameters dan_humu_page/inference_parameters_word_lm.yml \
-    --charset dan_humu_page/charset.pkl \
+    --image-dir images/ \
+    --model model.pt \
+    --parameters inference_parameters_word_lm.yml \
+    --charset charset.pkl \
     --use-language-model \
-    --output dan_humu_page/predict_word_lm/
+    --output predict_word_lm/
 ```
 
-It will create the following JSON file named `dan_humu_page/predict_word_lm/6e830f23-e70d-4399-8b94-f36ed3198575.json`
+It will create the following JSON file named from the image filename in the `predict_word_lm` folder:
 
 ```json
 {
diff --git a/tests/test_prediction.py b/tests/test_prediction.py
index 6affeeb9..a42e8975 100644
--- a/tests/test_prediction.py
+++ b/tests/test_prediction.py
@@ -298,9 +298,16 @@ def test_run_prediction(
     expected_prediction,
     tmp_path,
 ):
+    # Make tmpdir and copy needed image inside
+    image_dir = tmp_path / "images"
+    image_dir.mkdir()
+    shutil.copyfile(
+        (PREDICTION_DATA_PATH / "images" / image_name).with_suffix(".png"),
+        (image_dir / image_name).with_suffix(".png"),
+    )
+
     run_prediction(
-        image=(PREDICTION_DATA_PATH / "images" / image_name).with_suffix(".png"),
-        image_dir=None,
+        image_dir=image_dir,
         model=PREDICTION_DATA_PATH / "popp_line_model.pt",
         parameters=PREDICTION_DATA_PATH / "parameters.yml",
         charset=PREDICTION_DATA_PATH / "charset.pkl",
@@ -315,7 +322,7 @@ def test_run_prediction(
         temperature=temperature,
         predict_objects=False,
         max_object_height=None,
-        image_extension=None,
+        image_extension=".png",
         gpu_device=None,
         batch_size=1,
         tokens=parse_tokens(PREDICTION_DATA_PATH / "tokens.yml"),
@@ -495,7 +502,6 @@ def test_run_prediction_batch(
         )
 
     run_prediction(
-        image=None,
         image_dir=image_dir,
         model=PREDICTION_DATA_PATH / "popp_line_model.pt",
         parameters=PREDICTION_DATA_PATH / "parameters.yml",
@@ -645,7 +651,6 @@ def test_run_prediction_language_model(
     yaml.dump(params, (tmp_path / "parameters.yml").open("w"))
 
     run_prediction(
-        image=None,
         image_dir=image_dir,
         model=PREDICTION_DATA_PATH / "popp_line_model.pt",
         parameters=tmp_path / "parameters.yml",
-- 
GitLab