Skip to content
Snippets Groups Projects
Commit dcd45f8e authored by Manon Blanco's avatar Manon Blanco Committed by Mélodie Boillet
Browse files

Load image using torch + use training pre-processing function during prediction

parent 497c45d3
No related branches found
No related tags found
No related merge requests found
......@@ -3,12 +3,10 @@ import json
import os
import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision.io import ImageReadMode, read_image
from dan.datasets.utils import natural_sort
from dan.utils import token_to_ind
from dan.utils import read_image, token_to_ind
class OCRDataset(Dataset):
......@@ -82,14 +80,6 @@ class OCRDataset(Dataset):
)
return sample
@staticmethod
def load_image(path):
"""
Load an image as a torch.Tensor and scale the values between 0 and 1.
"""
img = read_image(path, mode=ImageReadMode.RGB)
return img.to(dtype=torch.get_default_dtype()).div(255)
def load_samples(self, paths_and_sets):
"""
Load images and labels
......@@ -116,7 +106,7 @@ class OCRDataset(Dataset):
)
if self.load_in_memory:
samples[-1]["img"] = self.preprocessing_transforms(
self.load_image(filename)
read_image(filename)
)
return samples
......@@ -126,10 +116,8 @@ class OCRDataset(Dataset):
"""
if self.load_in_memory:
return self.samples[i]["img"]
else:
return self.preprocessing_transforms(
self.load_image(self.samples[i]["path"])
)
return self.preprocessing_transforms(read_image(self.samples[i]["path"]))
def compute_final_size(self, img):
"""
......
......@@ -56,7 +56,7 @@ class OCRDatasetManager:
else None
)
self.preprocessing = get_preprocessing_transforms(
params["config"]["preprocessings"]
params["config"]["preprocessings"], to_pil_image=True
)
def load_datasets(self):
......
......@@ -58,20 +58,6 @@ def add_predict_parser(subcommands) -> None:
help="The extension of the images in the folder.",
default=".jpg",
)
parser.add_argument(
"--scale",
type=float,
default=1.0,
required=False,
help="Image scaling factor before feeding it to DAN",
)
parser.add_argument(
"--image-max-width",
type=int,
default=None,
required=False,
help="Image resizing before feeding it to DAN",
)
parser.add_argument(
"--temperature",
type=float,
......
......@@ -5,7 +5,6 @@ import pickle
from itertools import pairwise
from pathlib import Path
import cv2
import numpy as np
import torch
import yaml
......@@ -20,7 +19,7 @@ from dan.predict.attention import (
plot_attention,
split_text_and_confidences,
)
from dan.transforms import get_normalization_transforms
from dan.transforms import get_normalization_transforms, get_preprocessing_transforms
from dan.utils import ind_to_token, read_image
......@@ -76,22 +75,18 @@ class DAN:
self.encoder = encoder
self.decoder = decoder
self.normalization = get_normalization_transforms()
self.preprocessing_transforms = get_preprocessing_transforms(
parameters.get("preprocessings", [])
)
self.max_chars = parameters["max_char_prediction"]
def preprocess(self, input_image):
def preprocess(self, path):
"""
Preprocess an input_image.
:param input_image: The input image to preprocess.
Preprocess an image.
:param path: Image path
"""
assert isinstance(
input_image, np.ndarray
), "Input image must be an np.array in RGB"
input_image = np.asarray(input_image)
if len(input_image.shape) < 3:
input_image = cv2.cvtColor(input_image, cv2.COLOR_GRAY2RGB)
input_image = self.normalization(input_image)
return input_image
image = read_image(path)
return self.preprocessing_transforms(image)
def predict(
self,
......@@ -253,11 +248,10 @@ class DAN:
def process_image(
image,
image_path,
dan_model,
device,
output,
scale,
confidence_score,
confidence_score_levels,
attention_map,
......@@ -265,27 +259,18 @@ def process_image(
attention_map_scale,
word_separators,
line_separators,
image_max_width,
predict_objects,
threshold_method,
threshold_value,
):
# Load image and pre-process it
if image_max_width:
_, w, _ = read_image(image, scale=1).shape
ratio = image_max_width / w
im = read_image(image, ratio)
else:
im = read_image(image, scale=scale)
image = dan_model.preprocess(str(image_path))
logger.info("Image loaded.")
im_p = dan_model.preprocess(im)
logger.debug("Image pre-processed.")
# Convert to tensor of size (batch_size, channel, height, width) with batch_size=1
input_tensor = im_p.unsqueeze(0)
input_tensor = image.unsqueeze(0)
input_tensor = input_tensor.to(device)
input_sizes = [im_p.shape[1:]]
input_sizes = [image.shape[1:]]
# Parse delimiters to regex
word_separators = parse_delimiters(word_separators)
......@@ -347,11 +332,11 @@ def process_image(
# Save gif with attention map
if attention_map:
gif_filename = f"{output}/{image.stem}_{attention_map_level}.gif"
gif_filename = f"{output}/{image_path.stem}_{attention_map_level}.gif"
logger.info(f"Creating attention GIF in {gif_filename}")
# this returns polygons but unused for now.
plot_attention(
image=im,
image=image,
text=prediction["text"][0],
weights=prediction["attentions"][0],
level=attention_map_level,
......@@ -365,7 +350,7 @@ def process_image(
)
result["attention_gif"] = gif_filename
json_filename = f"{output}/{image.stem}.json"
json_filename = f"{output}/{image_path.stem}.json"
logger.info(f"Saving JSON prediction in {json_filename}")
save_json(Path(json_filename), result)
......@@ -377,7 +362,6 @@ def run(
parameters,
charset,
output,
scale,
confidence_score,
confidence_score_levels,
attention_map,
......@@ -386,7 +370,6 @@ def run(
word_separators,
line_separators,
temperature,
image_max_width,
predict_objects,
threshold_method,
threshold_value,
......@@ -401,14 +384,12 @@ def run(
:param parameters: Path to the YAML parameters file.
:param charset: Path to the charset.
:param output: Path to the output folder where the results will be saved.
:param scale: Scaling factor to resize the image.
:param confidence_score: Whether to compute confidence score.
:param attention_map: Whether to plot the attention map.
:param attention_map_level: Level of objects to extract.
:param attention_map_scale: Scaling factor for the attention map.
:param word_separators: List of word separators.
:param line_separators: List of line separators.
:param image_max_width: Resize image
:param predict_objects: Whether to extract objects.
:param threshold_method: Thresholding method. Should be in ["otsu", "simple"].
:param threshold_value: Thresholding value to use for the "simple" thresholding method.
......@@ -423,13 +404,14 @@ def run(
device = f"cuda{cuda_device}" if torch.cuda.is_available() else "cpu"
dan_model = DAN(device, temperature)
dan_model.load(model, parameters, charset, mode="eval")
if image:
images = image_dir.rglob(f"*{image_extension}") if not image else [image]
for image_name in images:
process_image(
image,
image_name,
dan_model,
device,
output,
scale,
confidence_score,
confidence_score_levels,
attention_map,
......@@ -437,28 +419,7 @@ def run(
attention_map_scale,
word_separators,
line_separators,
image_max_width,
predict_objects,
threshold_method,
threshold_value,
)
else:
for image_name in image_dir.rglob(f"*{image_extension}"):
process_image(
image_name,
dan_model,
device,
output,
scale,
confidence_score,
confidence_score_levels,
attention_map,
attention_map_level,
attention_map_scale,
word_separators,
line_separators,
image_max_width,
predict_objects,
threshold_method,
threshold_value,
)
......@@ -145,7 +145,7 @@ class ErosionDilation:
return {"image": augmented_image}
def get_preprocessing_transforms(preprocessings: list) -> Compose:
def get_preprocessing_transforms(preprocessings: list, to_pil_image=False) -> Compose:
"""
Returns a list of transformations to be applied to the image.
"""
......@@ -165,7 +165,10 @@ def get_preprocessing_transforms(preprocessings: list) -> Compose:
)
case Preprocessing.FixedWidthResize:
transforms.append(FixedWidthResize(width=preprocessing["fixed_width"]))
transforms.append(ToPILImage())
if to_pil_image:
transforms.append(ToPILImage())
return Compose(transforms)
......
# -*- coding: utf-8 -*-
import cv2
import torch
import torchvision.io as torchvision
# Layout begin-token to end-token
SEM_MATCHING_TOKENS = {"": "", "": "", "": "", "": "", "": "", "": ""}
......@@ -47,18 +47,13 @@ def pad_images(data):
return padded_data
def read_image(filename, scale=1.0):
def read_image(path):
"""
Read image and rescale it
:param filename: Image path
:param scale: Scaling factor before prediction
Read image with torch
:param path: Image path
"""
image = cv2.cvtColor(cv2.imread(str(filename)), cv2.COLOR_BGR2RGB)
if scale != 1.0:
width = int(image.shape[1] * scale)
height = int(image.shape[0] * scale)
image = cv2.resize(image, (width, height), interpolation=cv2.INTER_AREA)
return image
img = torchvision.read_image(path, mode=torchvision.ImageReadMode.RGB)
return img.to(dtype=torch.get_default_dtype()).div(255)
# Charset / labels conversion
......
......@@ -68,5 +68,11 @@ parameters:
dec_num_heads: int
dec_att_dropout: float
dec_res_dropout: float
preprocessings:
- type: str
max_height: int
max_width: int
fixed_height: int
fixed_width: int
```
2. Apply a trained DAN model on an image using the [predict command](../usage/predict.md).
......@@ -13,7 +13,6 @@ Use the `teklia-dan predict` command to apply a trained DAN model on an image.
| `--parameters` | Path to the YAML parameters file. | `Path` | |
| `--charset` | Path to the charset file. | `Path` | |
| `--output` | Path to the output folder. Results will be saved in this directory. | `Path` | |
| `--scale` | Image scaling factor before feeding it to DAN. | `float` | `1.0` |
| `--confidence-score` | Whether to return confidence scores. | `bool` | `False` |
| `--confidence-score-levels` | Level to return confidence scores. Should be any combination of `["line", "word", "char"]`. | `str` | |
| `--attention-map` | Whether to plot attention maps. | `bool` | `False` |
......@@ -37,7 +36,6 @@ teklia-dan predict \
--parameters dan_humu_page/parameters.yml \
--charset dan_humu_page/charset.pkl \
--output dan_humu_page/predict/ \
--scale 0.5 \
--confidence-score
```
It will create the following JSON file named `dan_humu_page/predict/example.json`
......@@ -60,7 +58,6 @@ teklia-dan predict \
--parameters dan_humu_page/parameters.yml \
--charset dan_humu_page/charset.pkl \
--output dan_humu_page/predict/ \
--scale 0.5 \
--confidence-score \
--attention-map \
```
......@@ -88,7 +85,6 @@ teklia-dan predict \
--parameters dan_humu_page/parameters.yml \
--charset dan_humu_page/charset.pkl \
--output dan_humu_page/predict/ \
--scale 0.5 \
--confidence-score \
--attention-map \
--attention-map-level word \
......@@ -118,7 +114,6 @@ teklia-dan predict \
--parameters dan_humu_page/parameters.yml \
--charset dan_humu_page/charset.pkl \
--output dan_humu_page/predict/ \
--scale 0.5 \
--attention-map \
--predict-objects \
--threshold-method otsu
......
---
version: 0.0.1
parameters:
mean: [166.8418783515498, 166.8418783515498, 166.8418783515498]
std: [34.084189571536385, 34.084189571536385, 34.084189571536385]
max_char_prediction: 200
encoder:
input_channels: 3
......@@ -22,3 +20,7 @@ parameters:
dec_num_heads: 4
dec_att_dropout: 0.1
dec_res_dropout: 0.1
preprocessings:
- type: "max_resize"
max_height: 1500
max_width: 1500
......@@ -6,7 +6,6 @@ import pytest
from dan.predict.prediction import DAN
from dan.predict.prediction import run as run_prediction
from dan.utils import read_image
@pytest.mark.parametrize(
......@@ -45,8 +44,8 @@ def test_predict(
mode="eval",
)
image = read_image(prediction_data_path / "images" / image_name)
image = dan_model.preprocess(image)
image_path = prediction_data_path / "images" / image_name
image = dan_model.preprocess(str(image_path))
input_tensor = image.unsqueeze(0)
input_tensor = input_tensor.to(device)
......@@ -104,7 +103,7 @@ def test_predict(
{"text": "ⓁP", "confidence": 0.94},
{"text": "ⒸM", "confidence": 0.93},
{"text": "ⓀCh", "confidence": 0.96},
{"text": "ⓄPlombier", "confidence": 0.94},
{"text": "ⓄPlombier", "confidence": 0.93},
{"text": "ⓅPatron?12241", "confidence": 0.93},
],
},
......@@ -258,7 +257,6 @@ def test_run_prediction(
parameters=prediction_data_path / "parameters.yml",
charset=prediction_data_path / "charset.pkl",
output=tmp_path,
scale=1,
confidence_score=True if confidence_score else False,
confidence_score_levels=confidence_score if confidence_score else [],
attention_map=False,
......@@ -267,7 +265,6 @@ def test_run_prediction(
word_separators=[" ", "\n"],
line_separators=["\n"],
temperature=temperature,
image_max_width=None,
predict_objects=False,
threshold_method="otsu",
threshold_value=0,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment