From 2330f7ee14c21fdbfbee49771850881171ceccf6 Mon Sep 17 00:00:00 2001
From: Manon blanco <blanco@teklia.com>
Date: Thu, 13 Jul 2023 12:44:07 +0000
Subject: [PATCH] Load image using torch + use training pre-processing
 function during prediction

---
 dan/manager/dataset.py               | 20 ++-----
 dan/manager/ocr.py                   |  4 +-
 dan/predict/__init__.py              | 14 -----
 dan/predict/prediction.py            | 80 ++++++++--------------------
 dan/transforms.py                    | 19 +++++--
 dan/utils.py                         | 17 +++---
 docs/get_started/training.md         |  6 +++
 docs/usage/predict.md                |  5 --
 tests/data/prediction/parameters.yml |  6 ++-
 tests/test_prediction.py             |  7 +--
 10 files changed, 60 insertions(+), 118 deletions(-)

diff --git a/dan/manager/dataset.py b/dan/manager/dataset.py
index 5f6ad28b..d441c6a1 100644
--- a/dan/manager/dataset.py
+++ b/dan/manager/dataset.py
@@ -3,12 +3,10 @@ import json
 import os
 
 import numpy as np
-import torch
 from torch.utils.data import Dataset
-from torchvision.io import ImageReadMode, read_image
 
 from dan.datasets.utils import natural_sort
-from dan.utils import token_to_ind
+from dan.utils import read_image, token_to_ind
 
 
 class OCRDataset(Dataset):
@@ -82,14 +80,6 @@ class OCRDataset(Dataset):
         )
         return sample
 
-    @staticmethod
-    def load_image(path):
-        """
-        Load an image as a torch.Tensor and scale the values between 0 and 1.
-        """
-        img = read_image(path, mode=ImageReadMode.RGB)
-        return img.to(dtype=torch.get_default_dtype()).div(255)
-
     def load_samples(self, paths_and_sets):
         """
         Load images and labels
@@ -116,7 +106,7 @@ class OCRDataset(Dataset):
             )
             if self.load_in_memory:
                 samples[-1]["img"] = self.preprocessing_transforms(
-                    self.load_image(filename)
+                    read_image(filename)
                 )
         return samples
 
@@ -126,10 +116,8 @@ class OCRDataset(Dataset):
         """
         if self.load_in_memory:
             return self.samples[i]["img"]
-        else:
-            return self.preprocessing_transforms(
-                self.load_image(self.samples[i]["path"])
-            )
+
+        return self.preprocessing_transforms(read_image(self.samples[i]["path"]))
 
     def compute_final_size(self, img):
         """
diff --git a/dan/manager/ocr.py b/dan/manager/ocr.py
index d25535b6..fb79581f 100644
--- a/dan/manager/ocr.py
+++ b/dan/manager/ocr.py
@@ -49,14 +49,14 @@ class OCRDatasetManager:
         self.params["config"]["padding_token"] = self.tokens["pad"]
 
         self.my_collate_function = OCRCollateFunction(self.params["config"])
-        self.normalization = get_normalization_transforms()
+        self.normalization = get_normalization_transforms(from_pil_image=True)
         self.augmentation = (
             get_augmentation_transforms()
             if self.params["config"]["augmentation"]
             else None
         )
         self.preprocessing = get_preprocessing_transforms(
-            params["config"]["preprocessings"]
+            params["config"]["preprocessings"], to_pil_image=True
         )
 
     def load_datasets(self):
diff --git a/dan/predict/__init__.py b/dan/predict/__init__.py
index fa81b61e..1b73e3b8 100644
--- a/dan/predict/__init__.py
+++ b/dan/predict/__init__.py
@@ -58,20 +58,6 @@ def add_predict_parser(subcommands) -> None:
         help="The extension of the images in the folder.",
         default=".jpg",
     )
-    parser.add_argument(
-        "--scale",
-        type=float,
-        default=1.0,
-        required=False,
-        help="Image scaling factor before feeding it to DAN",
-    )
-    parser.add_argument(
-        "--image-max-width",
-        type=int,
-        default=None,
-        required=False,
-        help="Image resizing before feeding it to DAN",
-    )
     parser.add_argument(
         "--temperature",
         type=float,
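The `OCRDataset.load_image` helper deleted above is not lost: the same logic moves into `dan.utils.read_image` (see the `dan/utils.py` hunk further down), so training and prediction now share one loader. A minimal sketch of that behavior, assuming only torchvision is installed; the wrapper name `load_rgb_tensor` is illustrative:

```python
import torch
from torchvision.io import ImageReadMode, read_image


def load_rgb_tensor(path: str) -> torch.Tensor:
    # Decode to a uint8 tensor of shape (3, H, W), forcing RGB so grayscale
    # sources gain three channels, then rescale values to the [0, 1] range.
    img = read_image(path, mode=ImageReadMode.RGB)
    return img.to(dtype=torch.get_default_dtype()).div(255)
```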
diff --git a/dan/predict/prediction.py b/dan/predict/prediction.py
index 9aa1cd5f..ff87ca3e 100644
--- a/dan/predict/prediction.py
+++ b/dan/predict/prediction.py
@@ -5,7 +5,6 @@ import pickle
 from itertools import pairwise
 from pathlib import Path
 
-import cv2
 import numpy as np
 import torch
 import yaml
@@ -20,7 +19,7 @@ from dan.predict.attention import (
     plot_attention,
     split_text_and_confidences,
 )
-from dan.transforms import get_normalization_transforms
+from dan.transforms import get_normalization_transforms, get_preprocessing_transforms
 from dan.utils import ind_to_token, read_image
 
 
@@ -76,22 +75,19 @@ class DAN:
         self.encoder = encoder
         self.decoder = decoder
         self.normalization = get_normalization_transforms()
+        self.preprocessing_transforms = get_preprocessing_transforms(
+            parameters.get("preprocessings", [])
+        )
         self.max_chars = parameters["max_char_prediction"]
 
-    def preprocess(self, input_image):
+    def preprocess(self, path):
         """
-        Preprocess an input_image.
-        :param input_image: The input image to preprocess.
+        Preprocess an image.
+        :param path: Path of the image to load and preprocess.
         """
-        assert isinstance(
-            input_image, np.ndarray
-        ), "Input image must be an np.array in RGB"
-        input_image = np.asarray(input_image)
-        if len(input_image.shape) < 3:
-            input_image = cv2.cvtColor(input_image, cv2.COLOR_GRAY2RGB)
-
-        input_image = self.normalization(input_image)
-        return input_image
+        image = read_image(path)
+        preprocessed_image = self.preprocessing_transforms(image)
+        return self.normalization(preprocessed_image)
 
     def predict(
         self,
@@ -253,11 +249,10 @@
 
 
 def process_image(
-    image,
+    image_path,
     dan_model,
     device,
     output,
-    scale,
     confidence_score,
     confidence_score_levels,
     attention_map,
@@ -265,27 +260,18 @@
     attention_map_scale,
     word_separators,
     line_separators,
-    image_max_width,
     predict_objects,
     threshold_method,
    threshold_value,
 ):
     # Load image and pre-process it
-    if image_max_width:
-        _, w, _ = read_image(image, scale=1).shape
-        ratio = image_max_width / w
-        im = read_image(image, ratio)
-    else:
-        im = read_image(image, scale=scale)
-
+    image = dan_model.preprocess(str(image_path))
     logger.info("Image loaded.")
-    im_p = dan_model.preprocess(im)
-    logger.debug("Image pre-processed.")
 
     # Convert to tensor of size (batch_size, channel, height, width) with batch_size=1
-    input_tensor = im_p.unsqueeze(0)
+    input_tensor = image.unsqueeze(0)
     input_tensor = input_tensor.to(device)
-    input_sizes = [im_p.shape[1:]]
+    input_sizes = [image.shape[1:]]
 
     # Parse delimiters to regex
     word_separators = parse_delimiters(word_separators)
@@ -347,11 +333,11 @@
 
     # Save gif with attention map
     if attention_map:
-        gif_filename = f"{output}/{image.stem}_{attention_map_level}.gif"
+        gif_filename = f"{output}/{image_path.stem}_{attention_map_level}.gif"
         logger.info(f"Creating attention GIF in {gif_filename}")
         # this returns polygons but unused for now.
         plot_attention(
-            image=im,
+            image=image,
             text=prediction["text"][0],
             weights=prediction["attentions"][0],
             level=attention_map_level,
@@ -365,7 +351,7 @@
         )
         result["attention_gif"] = gif_filename
 
-    json_filename = f"{output}/{image.stem}.json"
+    json_filename = f"{output}/{image_path.stem}.json"
     logger.info(f"Saving JSON prediction in {json_filename}")
     save_json(Path(json_filename), result)
 
@@ -377,7 +363,6 @@ def run(
     parameters,
     charset,
     output,
-    scale,
     confidence_score,
     confidence_score_levels,
     attention_map,
@@ -386,7 +371,6 @@
     word_separators,
     line_separators,
     temperature,
-    image_max_width,
     predict_objects,
     threshold_method,
     threshold_value,
@@ -401,14 +385,12 @@
     :param parameters: Path to the YAML parameters file.
     :param charset: Path to the charset.
     :param output: Path to the output folder where the results will be saved.
-    :param scale: Scaling factor to resize the image.
     :param confidence_score: Whether to compute confidence score.
     :param attention_map: Whether to plot the attention map.
     :param attention_map_level: Level of objects to extract.
     :param attention_map_scale: Scaling factor for the attention map.
     :param word_separators: List of word separators.
     :param line_separators: List of line separators.
-    :param image_max_width: Resize image
     :param predict_objects: Whether to extract objects.
     :param threshold_method: Thresholding method. Should be in ["otsu", "simple"].
     :param threshold_value: Thresholding value to use for the "simple" thresholding method.
@@ -423,13 +405,14 @@
     device = f"cuda{cuda_device}" if torch.cuda.is_available() else "cpu"
     dan_model = DAN(device, temperature)
     dan_model.load(model, parameters, charset, mode="eval")
-    if image:
+
+    images = image_dir.rglob(f"*{image_extension}") if not image else [image]
+    for image_name in images:
         process_image(
-            image,
+            image_name,
             dan_model,
             device,
             output,
-            scale,
             confidence_score,
             confidence_score_levels,
             attention_map,
@@ -437,28 +420,7 @@
             attention_map_scale,
             word_separators,
             line_separators,
-            image_max_width,
             predict_objects,
             threshold_method,
             threshold_value,
         )
-    else:
-        for image_name in image_dir.rglob(f"*{image_extension}"):
-            process_image(
-                image_name,
-                dan_model,
-                device,
-                output,
-                scale,
-                confidence_score,
-                confidence_score_levels,
-                attention_map,
-                attention_map_level,
-                attention_map_scale,
-                word_separators,
-                line_separators,
-                image_max_width,
-                predict_objects,
-                threshold_method,
-                threshold_value,
-            )
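With the single-image and `--image-dir` branches merged, `run()` now reduces to one loop over an iterable of paths. A self-contained sketch of that pattern, with `process_image` replaced by a `print` stub since its full argument list is beside the point here:

```python
from pathlib import Path


def iter_images(image, image_dir: Path, image_extension: str):
    # A single --image argument wins; otherwise walk the directory
    # recursively, as run() does with its extension-based glob.
    return image_dir.rglob(f"*{image_extension}") if not image else [image]


for image_name in iter_images(None, Path("pages"), ".jpg"):
    print(image_name)  # stand-in for process_image(image_name, ...)
```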
diff --git a/dan/transforms.py b/dan/transforms.py
index 0ad4420d..b0501e82 100644
--- a/dan/transforms.py
+++ b/dan/transforms.py
@@ -145,7 +145,9 @@ class ErosionDilation:
         return {"image": augmented_image}
 
 
-def get_preprocessing_transforms(preprocessings: list) -> Compose:
+def get_preprocessing_transforms(
+    preprocessings: list, to_pil_image: bool = False
+) -> Compose:
     """
     Returns a list of transformations to be applied to the image.
     """
@@ -165,7 +167,10 @@
             )
         case Preprocessing.FixedWidthResize:
             transforms.append(FixedWidthResize(width=preprocessing["fixed_width"]))
-    transforms.append(ToPILImage())
+
+    if to_pil_image:
+        transforms.append(ToPILImage())
+
     return Compose(transforms)
 
 
@@ -192,8 +197,14 @@ def get_augmentation_transforms() -> SomeOf:
     )
 
 
-def get_normalization_transforms() -> Compose:
+def get_normalization_transforms(from_pil_image: bool = False) -> Compose:
     """
     Returns a list of normalization transformations.
     """
-    return Compose([ToTensor(), Normalize(IMAGENET_MEAN, IMAGENET_STD)])
+    transforms = []
+
+    if from_pil_image:
+        transforms.append(ToTensor())
+
+    transforms.append(Normalize(IMAGENET_MEAN, IMAGENET_STD))
+    return Compose(transforms)
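The two new keyword flags make the PIL round-trip explicit, and the call sites in this patch show the intended pairing: the training manager (`dan/manager/ocr.py`) opts in, presumably so the augmentation stage can operate between the PIL output of preprocessing and the `ToTensor()` at the head of normalization, while prediction stays in tensor space throughout. A sketch of both assemblies, assuming an empty preprocessing list for brevity:

```python
from dan.transforms import get_normalization_transforms, get_preprocessing_transforms

preprocessings = []  # normally read from the "preprocessings" config section

# Training pairing (as in OCRDatasetManager): preprocessing ends on a PIL
# image; normalization starts with ToTensor() to convert back.
train_preprocessing = get_preprocessing_transforms(preprocessings, to_pil_image=True)
train_normalization = get_normalization_transforms(from_pil_image=True)

# Prediction pairing (as in the DAN class): tensors end to end, so both
# flags keep their False defaults.
predict_preprocessing = get_preprocessing_transforms(preprocessings)
predict_normalization = get_normalization_transforms()
```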
""" - return Compose([ToTensor(), Normalize(IMAGENET_MEAN, IMAGENET_STD)]) + transforms = [] + + if from_pil_image: + transforms.append(ToTensor()) + + transforms.append(Normalize(IMAGENET_MEAN, IMAGENET_STD)) + return Compose(transforms) diff --git a/dan/utils.py b/dan/utils.py index 0cb50ef6..1ec9a8cf 100644 --- a/dan/utils.py +++ b/dan/utils.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- -import cv2 import torch +import torchvision.io as torchvision # Layout begin-token to end-token SEM_MATCHING_TOKENS = {"ⓘ": "â’¾", "â““": "â’¹", "â“¢": "Ⓢ", "â“’": "â’¸", "â“Ÿ": "â“…", "â“": "â’¶"} @@ -47,18 +47,13 @@ def pad_images(data): return padded_data -def read_image(filename, scale=1.0): +def read_image(path): """ - Read image and rescale it - :param filename: Image path - :param scale: Scaling factor before prediction + Read image with torch + :param path: Path of the image to load. """ - image = cv2.cvtColor(cv2.imread(str(filename)), cv2.COLOR_BGR2RGB) - if scale != 1.0: - width = int(image.shape[1] * scale) - height = int(image.shape[0] * scale) - image = cv2.resize(image, (width, height), interpolation=cv2.INTER_AREA) - return image + img = torchvision.read_image(path, mode=torchvision.ImageReadMode.RGB) + return img.to(dtype=torch.get_default_dtype()).div(255) # Charset / labels conversion diff --git a/docs/get_started/training.md b/docs/get_started/training.md index b22a5a3b..5438363e 100644 --- a/docs/get_started/training.md +++ b/docs/get_started/training.md @@ -68,5 +68,11 @@ parameters: dec_num_heads: int dec_att_dropout: float dec_res_dropout: float + preprocessings: + - type: str + max_height: int + max_width: int + fixed_height: int + fixed_width: int ``` 2. Apply a trained DAN model on an image using the [predict command](../usage/predict.md). diff --git a/docs/usage/predict.md b/docs/usage/predict.md index 93933c13..7ff55bc1 100644 --- a/docs/usage/predict.md +++ b/docs/usage/predict.md @@ -13,7 +13,6 @@ Use the `teklia-dan predict` command to apply a trained DAN model on an image. | `--parameters` | Path to the YAML parameters file. | `Path` | | | `--charset` | Path to the charset file. | `Path` | | | `--output` | Path to the output folder. Results will be saved in this directory. | `Path` | | -| `--scale` | Image scaling factor before feeding it to DAN. | `float` | `1.0` | | `--confidence-score` | Whether to return confidence scores. | `bool` | `False` | | `--confidence-score-levels` | Level to return confidence scores. Should be any combination of `["line", "word", "char"]`. | `str` | | | `--attention-map` | Whether to plot attention maps. 
diff --git a/tests/data/prediction/parameters.yml b/tests/data/prediction/parameters.yml
index 76f665e2..bc56c1f6 100644
--- a/tests/data/prediction/parameters.yml
+++ b/tests/data/prediction/parameters.yml
@@ -1,8 +1,6 @@
 ---
 version: 0.0.1
 parameters:
-  mean: [166.8418783515498, 166.8418783515498, 166.8418783515498]
-  std: [34.084189571536385, 34.084189571536385, 34.084189571536385]
   max_char_prediction: 200
   encoder:
     input_channels: 3
@@ -22,3 +20,7 @@
       dec_num_heads: 4
       dec_att_dropout: 0.1
       dec_res_dropout: 0.1
+  preprocessings:
+    - type: "max_resize"
+      max_height: 1500
+      max_width: 1500
diff --git a/tests/test_prediction.py b/tests/test_prediction.py
index d677d63f..db87cdb4 100644
--- a/tests/test_prediction.py
+++ b/tests/test_prediction.py
@@ -6,7 +6,6 @@ import pytest
 
 from dan.predict.prediction import DAN
 from dan.predict.prediction import run as run_prediction
-from dan.utils import read_image
 
 
 @pytest.mark.parametrize(
@@ -45,8 +44,8 @@ def test_predict(
         mode="eval",
     )
 
-    image = read_image(prediction_data_path / "images" / image_name)
-    image = dan_model.preprocess(image)
+    image_path = prediction_data_path / "images" / image_name
+    image = dan_model.preprocess(str(image_path))
 
     input_tensor = image.unsqueeze(0)
     input_tensor = input_tensor.to(device)
@@ -258,7 +257,6 @@
         parameters=prediction_data_path / "parameters.yml",
         charset=prediction_data_path / "charset.pkl",
         output=tmp_path,
-        scale=1,
         confidence_score=True if confidence_score else False,
         confidence_score_levels=confidence_score if confidence_score else [],
         attention_map=False,
@@ -267,7 +265,6 @@
         word_separators=[" ", "\n"],
         line_separators=["\n"],
         temperature=temperature,
-        image_max_width=None,
         predict_objects=False,
         threshold_method="otsu",
         threshold_value=0,
-- 
GitLab