From 2330f7ee14c21fdbfbee49771850881171ceccf6 Mon Sep 17 00:00:00 2001
From: Manon blanco <blanco@teklia.com>
Date: Thu, 13 Jul 2023 12:44:07 +0000
Subject: [PATCH] Load image using torch + use training pre-processing function
 during prediction

---
 dan/manager/dataset.py               | 20 ++-----
 dan/manager/ocr.py                   |  4 +-
 dan/predict/__init__.py              | 14 -----
 dan/predict/prediction.py            | 80 ++++++++--------------------
 dan/transforms.py                    | 19 +++++--
 dan/utils.py                         | 17 +++---
 docs/get_started/training.md         |  6 +++
 docs/usage/predict.md                |  5 --
 tests/data/prediction/parameters.yml |  6 ++-
 tests/test_prediction.py             |  7 +--
 10 files changed, 60 insertions(+), 118 deletions(-)

diff --git a/dan/manager/dataset.py b/dan/manager/dataset.py
index 5f6ad28b..d441c6a1 100644
--- a/dan/manager/dataset.py
+++ b/dan/manager/dataset.py
@@ -3,12 +3,10 @@ import json
 import os
 
 import numpy as np
-import torch
 from torch.utils.data import Dataset
-from torchvision.io import ImageReadMode, read_image
 
 from dan.datasets.utils import natural_sort
-from dan.utils import token_to_ind
+from dan.utils import read_image, token_to_ind
 
 
 class OCRDataset(Dataset):
@@ -82,14 +80,6 @@ class OCRDataset(Dataset):
         )
         return sample
 
-    @staticmethod
-    def load_image(path):
-        """
-        Load an image as a torch.Tensor and scale the values between 0 and 1.
-        """
-        img = read_image(path, mode=ImageReadMode.RGB)
-        return img.to(dtype=torch.get_default_dtype()).div(255)
-
     def load_samples(self, paths_and_sets):
         """
         Load images and labels
@@ -116,7 +106,7 @@ class OCRDataset(Dataset):
                 )
                 if self.load_in_memory:
                     samples[-1]["img"] = self.preprocessing_transforms(
-                        self.load_image(filename)
+                        read_image(filename)
                     )
         return samples
 
@@ -126,10 +116,8 @@ class OCRDataset(Dataset):
         """
         if self.load_in_memory:
             return self.samples[i]["img"]
-        else:
-            return self.preprocessing_transforms(
-                self.load_image(self.samples[i]["path"])
-            )
+
+        return self.preprocessing_transforms(read_image(self.samples[i]["path"]))
 
     def compute_final_size(self, img):
         """
diff --git a/dan/manager/ocr.py b/dan/manager/ocr.py
index d25535b6..fb79581f 100644
--- a/dan/manager/ocr.py
+++ b/dan/manager/ocr.py
@@ -49,14 +49,14 @@ class OCRDatasetManager:
         self.params["config"]["padding_token"] = self.tokens["pad"]
 
         self.my_collate_function = OCRCollateFunction(self.params["config"])
-        self.normalization = get_normalization_transforms()
+        self.normalization = get_normalization_transforms(from_pil_image=True)
         self.augmentation = (
             get_augmentation_transforms()
             if self.params["config"]["augmentation"]
             else None
         )
         self.preprocessing = get_preprocessing_transforms(
-            params["config"]["preprocessings"]
+            params["config"]["preprocessings"], to_pil_image=True
         )
 
     def load_datasets(self):
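
The two flags passed here mirror each other: preprocessing now ends with `ToPILImage()` because the augmentation transforms operate on PIL images, and normalization starts with `ToTensor()` to convert back. A runnable sketch of that round-trip, with placeholder statistics standing in for `IMAGENET_MEAN` / `IMAGENET_STD`:

```python
import torch
from torchvision.transforms import Compose, Normalize, ToPILImage, ToTensor

# Placeholder statistics; dan.transforms defines IMAGENET_MEAN / IMAGENET_STD.
mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]

preprocessing = Compose([ToPILImage()])                      # to_pil_image=True
normalization = Compose([ToTensor(), Normalize(mean, std)])  # from_pil_image=True

tensor = torch.rand(3, 64, 128)         # stand-in for read_image(path)
pil_image = preprocessing(tensor)       # PIL image, as the augmentations expect
batch_ready = normalization(pil_image)  # back to a normalized float tensor
print(batch_ready.shape)                # torch.Size([3, 64, 128])
```
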
diff --git a/dan/predict/__init__.py b/dan/predict/__init__.py
index fa81b61e..1b73e3b8 100644
--- a/dan/predict/__init__.py
+++ b/dan/predict/__init__.py
@@ -58,20 +58,6 @@ def add_predict_parser(subcommands) -> None:
         help="The extension of the images in the folder.",
         default=".jpg",
     )
-    parser.add_argument(
-        "--scale",
-        type=float,
-        default=1.0,
-        required=False,
-        help="Image scaling factor before feeding it to DAN",
-    )
-    parser.add_argument(
-        "--image-max-width",
-        type=int,
-        default=None,
-        required=False,
-        help="Image resizing before feeding it to DAN",
-    )
     parser.add_argument(
         "--temperature",
         type=float,
diff --git a/dan/predict/prediction.py b/dan/predict/prediction.py
index 9aa1cd5f..ff87ca3e 100644
--- a/dan/predict/prediction.py
+++ b/dan/predict/prediction.py
@@ -5,7 +5,6 @@ import pickle
 from itertools import pairwise
 from pathlib import Path
 
-import cv2
 import numpy as np
 import torch
 import yaml
@@ -20,7 +19,7 @@ from dan.predict.attention import (
     plot_attention,
     split_text_and_confidences,
 )
-from dan.transforms import get_normalization_transforms
+from dan.transforms import get_normalization_transforms, get_preprocessing_transforms
 from dan.utils import ind_to_token, read_image
 
 
@@ -76,22 +75,19 @@ class DAN:
         self.encoder = encoder
         self.decoder = decoder
         self.normalization = get_normalization_transforms()
+        self.preprocessing_transforms = get_preprocessing_transforms(
+            parameters.get("preprocessings", [])
+        )
         self.max_chars = parameters["max_char_prediction"]
 
-    def preprocess(self, input_image):
+    def preprocess(self, path):
         """
-        Preprocess an input_image.
-        :param input_image: The input image to preprocess.
+        Load an image and apply the preprocessing and normalization transforms.
+        :param path: Path of the image to load and preprocess.
         """
-        assert isinstance(
-            input_image, np.ndarray
-        ), "Input image must be an np.array in RGB"
-        input_image = np.asarray(input_image)
-        if len(input_image.shape) < 3:
-            input_image = cv2.cvtColor(input_image, cv2.COLOR_GRAY2RGB)
-
-        input_image = self.normalization(input_image)
-        return input_image
+        image = read_image(path)
+        preprocessed_image = self.preprocessing_transforms(image)
+        return self.normalization(preprocessed_image)
 
     def predict(
         self,
@@ -253,11 +249,10 @@ class DAN:
 
 
 def process_image(
-    image,
+    image_path,
     dan_model,
     device,
     output,
-    scale,
     confidence_score,
     confidence_score_levels,
     attention_map,
@@ -265,27 +260,18 @@ def process_image(
     attention_map_scale,
     word_separators,
     line_separators,
-    image_max_width,
     predict_objects,
     threshold_method,
     threshold_value,
 ):
     # Load image and pre-process it
-    if image_max_width:
-        _, w, _ = read_image(image, scale=1).shape
-        ratio = image_max_width / w
-        im = read_image(image, ratio)
-    else:
-        im = read_image(image, scale=scale)
-
+    image = dan_model.preprocess(str(image_path))
     logger.info("Image loaded.")
-    im_p = dan_model.preprocess(im)
-    logger.debug("Image pre-processed.")
 
     # Convert to tensor of size (batch_size, channel, height, width) with batch_size=1
-    input_tensor = im_p.unsqueeze(0)
+    input_tensor = image.unsqueeze(0)
     input_tensor = input_tensor.to(device)
-    input_sizes = [im_p.shape[1:]]
+    input_sizes = [image.shape[1:]]
 
     # Parse delimiters to regex
     word_separators = parse_delimiters(word_separators)
@@ -347,11 +333,11 @@ def process_image(
 
     # Save gif with attention map
     if attention_map:
-        gif_filename = f"{output}/{image.stem}_{attention_map_level}.gif"
+        gif_filename = f"{output}/{image_path.stem}_{attention_map_level}.gif"
         logger.info(f"Creating attention GIF in {gif_filename}")
         # this returns polygons but unused for now.
         plot_attention(
-            image=im,
+            image=image,
             text=prediction["text"][0],
             weights=prediction["attentions"][0],
             level=attention_map_level,
@@ -365,7 +351,7 @@ def process_image(
         )
         result["attention_gif"] = gif_filename
 
-    json_filename = f"{output}/{image.stem}.json"
+    json_filename = f"{output}/{image_path.stem}.json"
     logger.info(f"Saving JSON prediction in {json_filename}")
     save_json(Path(json_filename), result)
 
@@ -377,7 +363,6 @@ def run(
     parameters,
     charset,
     output,
-    scale,
     confidence_score,
     confidence_score_levels,
     attention_map,
@@ -386,7 +371,6 @@ def run(
     word_separators,
     line_separators,
     temperature,
-    image_max_width,
     predict_objects,
     threshold_method,
     threshold_value,
@@ -401,14 +385,12 @@ def run(
     :param parameters: Path to the YAML parameters file.
     :param charset: Path to the charset.
     :param output: Path to the output folder where the results will be saved.
-    :param scale: Scaling factor to resize the image.
     :param confidence_score: Whether to compute confidence score.
     :param attention_map: Whether to plot the attention map.
     :param attention_map_level: Level of objects to extract.
     :param attention_map_scale: Scaling factor for the attention map.
     :param word_separators: List of word separators.
     :param line_separators: List of line separators.
-    :param image_max_width: Resize image
     :param predict_objects: Whether to extract objects.
     :param threshold_method: Thresholding method. Should be in ["otsu", "simple"].
     :param threshold_value: Thresholding value to use for the "simple" thresholding method.
@@ -423,13 +405,14 @@ def run(
     device = f"cuda{cuda_device}" if torch.cuda.is_available() else "cpu"
     dan_model = DAN(device, temperature)
     dan_model.load(model, parameters, charset, mode="eval")
-    if image:
+
+    images = image_dir.rglob(f"*{image_extension}") if not image else [image]
+    for image_path in images:
         process_image(
-            image,
+            image_path,
             dan_model,
             device,
             output,
-            scale,
             confidence_score,
             confidence_score_levels,
             attention_map,
@@ -437,28 +420,7 @@ def run(
             attention_map_scale,
             word_separators,
             line_separators,
-            image_max_width,
             predict_objects,
             threshold_method,
             threshold_value,
         )
-    else:
-        for image_name in image_dir.rglob(f"*{image_extension}"):
-            process_image(
-                image_name,
-                dan_model,
-                device,
-                output,
-                scale,
-                confidence_score,
-                confidence_score_levels,
-                attention_map,
-                attention_map_level,
-                attention_map_scale,
-                word_separators,
-                line_separators,
-                image_max_width,
-                predict_objects,
-                threshold_method,
-                threshold_value,
-            )
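
`run` now builds a single iterable of paths — either the one `--image` or a recursive glob over `--image-dir` — and calls `process_image` once per path, with all loading and resizing delegated to `DAN.preprocess`. A sketch of that control flow, with hypothetical stand-ins for the parsed CLI arguments:

```python
from pathlib import Path

# Hypothetical stand-ins for the parsed CLI arguments.
image = None               # --image: a single file, or None
image_dir = Path("pages")  # --image-dir
image_extension = ".jpg"   # --image-extension

# One code path for both modes, as in run() above.
images = image_dir.rglob(f"*{image_extension}") if not image else [image]
for image_path in images:
    # process_image(image_path, dan_model, device, output, ...) in the real code
    print(image_path)
```
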
diff --git a/dan/transforms.py b/dan/transforms.py
index 0ad4420d..b0501e82 100644
--- a/dan/transforms.py
+++ b/dan/transforms.py
@@ -145,7 +145,9 @@ class ErosionDilation:
         return {"image": augmented_image}
 
 
-def get_preprocessing_transforms(preprocessings: list) -> Compose:
+def get_preprocessing_transforms(
+    preprocessings: list, to_pil_image: bool = False
+) -> Compose:
     """
     Returns a list of transformations to be applied to the image.
     """
@@ -165,7 +167,10 @@ def get_preprocessing_transforms(preprocessings: list) -> Compose:
                 )
             case Preprocessing.FixedWidthResize:
                 transforms.append(FixedWidthResize(width=preprocessing["fixed_width"]))
-    transforms.append(ToPILImage())
+
+    if to_pil_image:
+        transforms.append(ToPILImage())
+
     return Compose(transforms)
 
 
@@ -192,8 +197,14 @@ def get_augmentation_transforms() -> SomeOf:
     )
 
 
-def get_normalization_transforms() -> Compose:
+def get_normalization_transforms(from_pil_image: bool = False) -> Compose:
     """
     Returns a list of normalization transformations.
     """
-    return Compose([ToTensor(), Normalize(IMAGENET_MEAN, IMAGENET_STD)])
+    transforms = []
+
+    if from_pil_image:
+        transforms.append(ToTensor())
+
+    transforms.append(Normalize(IMAGENET_MEAN, IMAGENET_STD))
+    return Compose(transforms)
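
With both flags defaulting to `False`, prediction pipelines stay tensor-only end to end, while training opts into the PIL round-trip. A hedged usage sketch, assuming the `max_resize` preprocessing type used in the test fixture further below:

```python
from dan.transforms import get_normalization_transforms, get_preprocessing_transforms

preprocessings = [{"type": "max_resize", "max_height": 1500, "max_width": 1500}]

# Prediction: tensor in, tensor out; no ToPILImage/ToTensor pair.
predict_preprocessing = get_preprocessing_transforms(preprocessings)
predict_normalization = get_normalization_transforms()

# Training: end preprocessing on PIL images, re-enter tensors at normalization.
train_preprocessing = get_preprocessing_transforms(preprocessings, to_pil_image=True)
train_normalization = get_normalization_transforms(from_pil_image=True)
```
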
diff --git a/dan/utils.py b/dan/utils.py
index 0cb50ef6..1ec9a8cf 100644
--- a/dan/utils.py
+++ b/dan/utils.py
@@ -1,6 +1,6 @@
 # -*- coding: utf-8 -*-
-import cv2
 import torch
+import torchvision.io
 
 # Layout begin-token to end-token
 SEM_MATCHING_TOKENS = {"ⓘ": "Ⓘ", "ⓓ": "Ⓓ", "ⓢ": "Ⓢ", "ⓒ": "Ⓒ", "ⓟ": "Ⓟ", "ⓐ": "Ⓐ"}
@@ -47,18 +47,13 @@ def pad_images(data):
     return padded_data
 
 
-def read_image(filename, scale=1.0):
+def read_image(path):
     """
-    Read image and rescale it
-    :param filename: Image path
-    :param scale: Scaling factor before prediction
+    Read an image with torch and scale its values to [0, 1].
+    :param path: Path of the image to load.
     """
-    image = cv2.cvtColor(cv2.imread(str(filename)), cv2.COLOR_BGR2RGB)
-    if scale != 1.0:
-        width = int(image.shape[1] * scale)
-        height = int(image.shape[0] * scale)
-        image = cv2.resize(image, (width, height), interpolation=cv2.INTER_AREA)
-    return image
+    img = torchvision.io.read_image(path, mode=torchvision.io.ImageReadMode.RGB)
+    return img.to(dtype=torch.get_default_dtype()).div(255)
 
 
 # Charset / labels conversion
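
The torch-based `read_image` replaces the OpenCV loader and drops the `scale` argument, since resizing is now owned by the preprocessing transforms. A quick check of the new contract, with a hypothetical file name — note the conventions also change from `cv2.imread` (HWC, uint8, BGR) to channels-first float RGB:

```python
import torch

from dan.utils import read_image

img = read_image("page.jpg")  # hypothetical path
assert img.ndim == 3 and img.shape[0] == 3                  # CHW, three RGB channels
assert img.dtype == torch.get_default_dtype()               # float, not uint8
assert 0.0 <= img.min().item() and img.max().item() <= 1.0  # divided by 255
```
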
diff --git a/docs/get_started/training.md b/docs/get_started/training.md
index b22a5a3b..5438363e 100644
--- a/docs/get_started/training.md
+++ b/docs/get_started/training.md
@@ -68,5 +68,11 @@ parameters:
     dec_num_heads: int
     dec_att_dropout: float
     dec_res_dropout: float
+  preprocessings:
+    - type: str
+      max_height: int
+      max_width: int
+      fixed_height: int
+      fixed_width: int
 ```
 2. Apply a trained DAN model on an image using the [predict command](../usage/predict.md).
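
The schema block above lists every preprocessing key; in practice each entry only carries the keys its `type` requires. A sketch of parsing such a block, mirroring the `max_resize` entry added to the test parameters below:

```python
import yaml

# Example parameters snippet (assumed shape; "max_resize" matches the test fixture).
config = yaml.safe_load("""
preprocessings:
  - type: max_resize
    max_height: 1500
    max_width: 1500
""")

for preprocessing in config["preprocessings"]:
    print(preprocessing["type"], preprocessing.get("max_height"))
```
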
diff --git a/docs/usage/predict.md b/docs/usage/predict.md
index 93933c13..7ff55bc1 100644
--- a/docs/usage/predict.md
+++ b/docs/usage/predict.md
@@ -13,7 +13,6 @@ Use the `teklia-dan predict` command to apply a trained DAN model on an image.
 | `--parameters`              | Path to the YAML parameters file.                                                            | `Path`  |               |
 | `--charset`                 | Path to the charset file.                                                                    | `Path`  |               |
 | `--output`                  | Path to the output folder. Results will be saved in this directory.                          | `Path`  |               |
-| `--scale`                   | Image scaling factor before feeding it to DAN.                                               | `float` | `1.0`         |
 | `--confidence-score`        | Whether to return confidence scores.                                                         | `bool`  | `False`       |
 | `--confidence-score-levels` | Level to return confidence scores. Should be any combination of `["line", "word", "char"]`.  | `str`   |               |
 | `--attention-map`           | Whether to plot attention maps.                                                              | `bool`  | `False`       |
@@ -37,7 +36,6 @@ teklia-dan predict \
     --parameters dan_humu_page/parameters.yml \
     --charset dan_humu_page/charset.pkl \
     --output dan_humu_page/predict/ \
-    --scale 0.5 \
     --confidence-score
 ```
 It will create the following JSON file named `dan_humu_page/predict/example.json`
@@ -60,7 +58,6 @@ teklia-dan predict \
     --parameters dan_humu_page/parameters.yml \
     --charset dan_humu_page/charset.pkl \
     --output dan_humu_page/predict/ \
-    --scale 0.5 \
     --confidence-score \
     --attention-map
 ```
@@ -88,7 +85,6 @@ teklia-dan predict \
     --parameters dan_humu_page/parameters.yml \
     --charset dan_humu_page/charset.pkl \
     --output dan_humu_page/predict/ \
-    --scale 0.5 \
     --confidence-score \
     --attention-map \
     --attention-map-level word \
@@ -118,7 +114,6 @@ teklia-dan predict \
     --parameters dan_humu_page/parameters.yml \
     --charset dan_humu_page/charset.pkl \
     --output dan_humu_page/predict/ \
-    --scale 0.5 \
     --attention-map \
     --predict-objects \
     --threshold-method otsu
diff --git a/tests/data/prediction/parameters.yml b/tests/data/prediction/parameters.yml
index 76f665e2..bc56c1f6 100644
--- a/tests/data/prediction/parameters.yml
+++ b/tests/data/prediction/parameters.yml
@@ -1,8 +1,6 @@
 ---
 version: 0.0.1
 parameters:
-  mean: [166.8418783515498, 166.8418783515498, 166.8418783515498]
-  std: [34.084189571536385, 34.084189571536385, 34.084189571536385]
   max_char_prediction: 200
   encoder:
     input_channels: 3
@@ -22,3 +20,7 @@ parameters:
     dec_num_heads: 4
     dec_att_dropout: 0.1
     dec_res_dropout: 0.1
+  preprocessings:
+    - type: "max_resize"
+      max_height: 1500
+      max_width: 1500
diff --git a/tests/test_prediction.py b/tests/test_prediction.py
index d677d63f..db87cdb4 100644
--- a/tests/test_prediction.py
+++ b/tests/test_prediction.py
@@ -6,7 +6,6 @@ import pytest
 
 from dan.predict.prediction import DAN
 from dan.predict.prediction import run as run_prediction
-from dan.utils import read_image
 
 
 @pytest.mark.parametrize(
@@ -45,8 +44,8 @@ def test_predict(
         mode="eval",
     )
 
-    image = read_image(prediction_data_path / "images" / image_name)
-    image = dan_model.preprocess(image)
+    image_path = prediction_data_path / "images" / image_name
+    image = dan_model.preprocess(str(image_path))
 
     input_tensor = image.unsqueeze(0)
     input_tensor = input_tensor.to(device)
@@ -258,7 +257,6 @@ def test_run_prediction(
         parameters=prediction_data_path / "parameters.yml",
         charset=prediction_data_path / "charset.pkl",
         output=tmp_path,
-        scale=1,
         confidence_score=True if confidence_score else False,
         confidence_score_levels=confidence_score if confidence_score else [],
         attention_map=False,
@@ -267,7 +265,6 @@ def test_run_prediction(
         word_separators=[" ", "\n"],
         line_separators=["\n"],
         temperature=temperature,
-        image_max_width=None,
         predict_objects=False,
         threshold_method="otsu",
         threshold_value=0,
-- 
GitLab