Commit dcd45f8e authored by Manon Blanco, committed by Mélodie Boillet

Load image using torch + use training pre-processing function during prediction

parent 497c45d3
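In short: image loading now goes through torchvision instead of OpenCV, the `--scale` and `--image-max-width` CLI flags are removed, and prediction applies the same `preprocessings` pipeline that was declared for training. Below is a minimal sketch of the resulting predict-time flow; the calls mirror the diff below, but the artifact paths are hypothetical placeholders, not files shipped with this commit.

```python
from pathlib import Path

import torch

from dan.predict.prediction import DAN

device = "cuda" if torch.cuda.is_available() else "cpu"

# Hypothetical artifact paths, for illustration only.
model_path = Path("dan_humu_page/model.pt")
parameters = Path("dan_humu_page/parameters.yml")
charset = Path("dan_humu_page/charset.pkl")

dan_model = DAN(device, 1.0)  # positional args: device, temperature (as in the diff)
dan_model.load(model_path, parameters, charset, mode="eval")

# `preprocess` now takes a path: it reads the image with torchvision and
# applies the training-time preprocessing transforms from parameters.yml.
image = dan_model.preprocess(str(Path("dan_humu_page/example.jpg")))

# Batch of size 1, shape (1, channel, height, width), on the model's device.
input_tensor = image.unsqueeze(0).to(device)
input_sizes = [image.shape[1:]]
```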
@@ -3,12 +3,10 @@ import json
 import os

 import numpy as np
-import torch
 from torch.utils.data import Dataset
-from torchvision.io import ImageReadMode, read_image

 from dan.datasets.utils import natural_sort
-from dan.utils import token_to_ind
+from dan.utils import read_image, token_to_ind


 class OCRDataset(Dataset):
@@ -82,14 +80,6 @@ class OCRDataset(Dataset):
         )
         return sample

-    @staticmethod
-    def load_image(path):
-        """
-        Load an image as a torch.Tensor and scale the values between 0 and 1.
-        """
-        img = read_image(path, mode=ImageReadMode.RGB)
-        return img.to(dtype=torch.get_default_dtype()).div(255)
-
    def load_samples(self, paths_and_sets):
        """
        Load images and labels
@@ -116,7 +106,7 @@ class OCRDataset(Dataset):
            )
            if self.load_in_memory:
                samples[-1]["img"] = self.preprocessing_transforms(
-                    self.load_image(filename)
+                    read_image(filename)
                )
        return samples

@@ -126,10 +116,8 @@ class OCRDataset(Dataset):
        """
        if self.load_in_memory:
            return self.samples[i]["img"]
-        else:
-            return self.preprocessing_transforms(
-                self.load_image(self.samples[i]["path"])
-            )
+        return self.preprocessing_transforms(read_image(self.samples[i]["path"]))

    def compute_final_size(self, img):
        """
...
@@ -56,7 +56,7 @@ class OCRDatasetManager:
            else None
        )
        self.preprocessing = get_preprocessing_transforms(
-            params["config"]["preprocessings"]
+            params["config"]["preprocessings"], to_pil_image=True
        )

    def load_datasets(self):
...
@@ -58,20 +58,6 @@ def add_predict_parser(subcommands) -> None:
        help="The extension of the images in the folder.",
        default=".jpg",
    )
-    parser.add_argument(
-        "--scale",
-        type=float,
-        default=1.0,
-        required=False,
-        help="Image scaling factor before feeding it to DAN",
-    )
-    parser.add_argument(
-        "--image-max-width",
-        type=int,
-        default=None,
-        required=False,
-        help="Image resizing before feeding it to DAN",
-    )
    parser.add_argument(
        "--temperature",
        type=float,
...
@@ -5,7 +5,6 @@ import pickle
 from itertools import pairwise
 from pathlib import Path

-import cv2
 import numpy as np
 import torch
 import yaml
@@ -20,7 +19,7 @@ from dan.predict.attention import (
     plot_attention,
     split_text_and_confidences,
 )
-from dan.transforms import get_normalization_transforms
+from dan.transforms import get_normalization_transforms, get_preprocessing_transforms
 from dan.utils import ind_to_token, read_image
@@ -76,22 +75,18 @@ class DAN:
        self.encoder = encoder
        self.decoder = decoder
        self.normalization = get_normalization_transforms()
+        self.preprocessing_transforms = get_preprocessing_transforms(
+            parameters.get("preprocessings", [])
+        )
        self.max_chars = parameters["max_char_prediction"]

-    def preprocess(self, input_image):
+    def preprocess(self, path):
        """
-        Preprocess an input_image.
-        :param input_image: The input image to preprocess.
+        Preprocess an image.
+        :param path: Image path
        """
-        assert isinstance(
-            input_image, np.ndarray
-        ), "Input image must be an np.array in RGB"
-        input_image = np.asarray(input_image)
-        if len(input_image.shape) < 3:
-            input_image = cv2.cvtColor(input_image, cv2.COLOR_GRAY2RGB)
-        input_image = self.normalization(input_image)
-        return input_image
+        image = read_image(path)
+        return self.preprocessing_transforms(image)

    def predict(
        self,
@@ -253,11 +248,10 @@ class DAN:
 def process_image(
-    image,
+    image_path,
     dan_model,
     device,
     output,
-    scale,
     confidence_score,
     confidence_score_levels,
     attention_map,
@@ -265,27 +259,18 @@ def process_image(
     attention_map_scale,
     word_separators,
     line_separators,
-    image_max_width,
     predict_objects,
     threshold_method,
     threshold_value,
 ):
     # Load image and pre-process it
-    if image_max_width:
-        _, w, _ = read_image(image, scale=1).shape
-        ratio = image_max_width / w
-        im = read_image(image, ratio)
-    else:
-        im = read_image(image, scale=scale)
+    image = dan_model.preprocess(str(image_path))
     logger.info("Image loaded.")

-    im_p = dan_model.preprocess(im)
-    logger.debug("Image pre-processed.")
-
     # Convert to tensor of size (batch_size, channel, height, width) with batch_size=1
-    input_tensor = im_p.unsqueeze(0)
+    input_tensor = image.unsqueeze(0)
     input_tensor = input_tensor.to(device)
-    input_sizes = [im_p.shape[1:]]
+    input_sizes = [image.shape[1:]]

     # Parse delimiters to regex
     word_separators = parse_delimiters(word_separators)
@@ -347,11 +332,11 @@ def process_image(
     # Save gif with attention map
     if attention_map:
-        gif_filename = f"{output}/{image.stem}_{attention_map_level}.gif"
+        gif_filename = f"{output}/{image_path.stem}_{attention_map_level}.gif"
         logger.info(f"Creating attention GIF in {gif_filename}")
         # this returns polygons but unused for now.
         plot_attention(
-            image=im,
+            image=image,
             text=prediction["text"][0],
             weights=prediction["attentions"][0],
             level=attention_map_level,
@@ -365,7 +350,7 @@ def process_image(
         )
         result["attention_gif"] = gif_filename

-    json_filename = f"{output}/{image.stem}.json"
+    json_filename = f"{output}/{image_path.stem}.json"
     logger.info(f"Saving JSON prediction in {json_filename}")
     save_json(Path(json_filename), result)
@@ -377,7 +362,6 @@ def run(
     parameters,
     charset,
     output,
-    scale,
     confidence_score,
     confidence_score_levels,
     attention_map,
@@ -386,7 +370,6 @@ def run(
     word_separators,
     line_separators,
     temperature,
-    image_max_width,
     predict_objects,
     threshold_method,
     threshold_value,
@@ -401,14 +384,12 @@ def run(
     :param parameters: Path to the YAML parameters file.
     :param charset: Path to the charset.
     :param output: Path to the output folder where the results will be saved.
-    :param scale: Scaling factor to resize the image.
     :param confidence_score: Whether to compute confidence score.
     :param attention_map: Whether to plot the attention map.
     :param attention_map_level: Level of objects to extract.
     :param attention_map_scale: Scaling factor for the attention map.
     :param word_separators: List of word separators.
     :param line_separators: List of line separators.
-    :param image_max_width: Resize image
     :param predict_objects: Whether to extract objects.
     :param threshold_method: Thresholding method. Should be in ["otsu", "simple"].
     :param threshold_value: Thresholding value to use for the "simple" thresholding method.
@@ -423,13 +404,14 @@ def run(
     device = f"cuda{cuda_device}" if torch.cuda.is_available() else "cpu"
     dan_model = DAN(device, temperature)
     dan_model.load(model, parameters, charset, mode="eval")
-    if image:
+
+    images = image_dir.rglob(f"*{image_extension}") if not image else [image]
+    for image_name in images:
         process_image(
-            image,
+            image_name,
             dan_model,
             device,
             output,
-            scale,
             confidence_score,
             confidence_score_levels,
             attention_map,
@@ -437,28 +419,7 @@ def run(
             attention_map_level,
             attention_map_scale,
             word_separators,
             line_separators,
-            image_max_width,
             predict_objects,
             threshold_method,
             threshold_value,
         )
-    else:
-        for image_name in image_dir.rglob(f"*{image_extension}"):
-            process_image(
-                image_name,
-                dan_model,
-                device,
-                output,
-                scale,
-                confidence_score,
-                confidence_score_levels,
-                attention_map,
-                attention_map_level,
-                attention_map_scale,
-                word_separators,
-                line_separators,
-                image_max_width,
-                predict_objects,
-                threshold_method,
-                threshold_value,
-            )
@@ -145,7 +145,7 @@ class ErosionDilation:
         return {"image": augmented_image}


-def get_preprocessing_transforms(preprocessings: list) -> Compose:
+def get_preprocessing_transforms(preprocessings: list, to_pil_image=False) -> Compose:
     """
     Returns a list of transformations to be applied to the image.
     """
@@ -165,7 +165,10 @@ def get_preprocessing_transforms(preprocessings: list) -> Compose:
            )
        case Preprocessing.FixedWidthResize:
            transforms.append(FixedWidthResize(width=preprocessing["fixed_width"]))
-    transforms.append(ToPILImage())
+
+    if to_pil_image:
+        transforms.append(ToPILImage())
+
    return Compose(transforms)
...
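Design note on the hunk above: the trailing `ToPILImage()` used to be unconditional; it is now opt-in so prediction can stay on torch tensors while training (which passes `to_pil_image=True` in the dataset-manager hunk earlier) still hands PIL images to its pipeline, presumably because the augmentations operate on PIL images. A small sketch of the two call sites, using a hypothetical `preprocessings` config in the schema documented further down:

```python
from dan.transforms import get_preprocessing_transforms

# Hypothetical config, following the documented schema (type / max_height / max_width).
preprocessings = [
    {"type": "max_resize", "max_height": 1500, "max_width": 1500},
]

# Training side: convert back to PIL images for the downstream pipeline.
train_preprocessing = get_preprocessing_transforms(preprocessings, to_pil_image=True)

# Prediction side: keep torch tensors (to_pil_image defaults to False).
predict_preprocessing = get_preprocessing_transforms(preprocessings)
```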
 # -*- coding: utf-8 -*-
-import cv2
 import torch
+import torchvision.io as torchvision

 # Layout begin-token to end-token
 SEM_MATCHING_TOKENS = {"": "", "": "", "": "", "": "", "": "", "": ""}

@@ -47,18 +47,13 @@ def pad_images(data):
     return padded_data


-def read_image(filename, scale=1.0):
+def read_image(path):
     """
-    Read image and rescale it
-    :param filename: Image path
-    :param scale: Scaling factor before prediction
+    Read image with torch
+    :param path: Image path
     """
-    image = cv2.cvtColor(cv2.imread(str(filename)), cv2.COLOR_BGR2RGB)
-    if scale != 1.0:
-        width = int(image.shape[1] * scale)
-        height = int(image.shape[0] * scale)
-        image = cv2.resize(image, (width, height), interpolation=cv2.INTER_AREA)
-    return image
+    img = torchvision.read_image(path, mode=torchvision.ImageReadMode.RGB)
+    return img.to(dtype=torch.get_default_dtype()).div(255)


 # Charset / labels conversion
...
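For reference, the reworked helper above decodes directly to a channels-first RGB tensor scaled to `[0, 1]`; the ad-hoc `scale` resizing is gone entirely, since resizing now belongs to the `preprocessings` config. A quick sketch of what callers get back, assuming a hypothetical image path:

```python
import torch

from dan.utils import read_image

img = read_image("example.jpg")  # hypothetical path

# torchvision decodes to (3, height, width): channels first, RGB order.
assert img.ndim == 3 and img.shape[0] == 3

# Values are scaled to [0, 1] in torch's default dtype (float32 unless changed).
assert img.dtype == torch.get_default_dtype()
assert 0.0 <= img.min().item() and img.max().item() <= 1.0
```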
@@ -68,5 +68,11 @@ parameters:
   dec_num_heads: int
   dec_att_dropout: float
   dec_res_dropout: float
+  preprocessings:
+    - type: str
+      max_height: int
+      max_width: int
+      fixed_height: int
+      fixed_width: int
 ```
 2. Apply a trained DAN model on an image using the [predict command](../usage/predict.md).
@@ -13,7 +13,6 @@ Use the `teklia-dan predict` command to apply a trained DAN model on an image.
 | `--parameters` | Path to the YAML parameters file. | `Path` | |
 | `--charset` | Path to the charset file. | `Path` | |
 | `--output` | Path to the output folder. Results will be saved in this directory. | `Path` | |
-| `--scale` | Image scaling factor before feeding it to DAN. | `float` | `1.0` |
 | `--confidence-score` | Whether to return confidence scores. | `bool` | `False` |
 | `--confidence-score-levels` | Level to return confidence scores. Should be any combination of `["line", "word", "char"]`. | `str` | |
 | `--attention-map` | Whether to plot attention maps. | `bool` | `False` |
@@ -37,7 +36,6 @@ teklia-dan predict \
   --parameters dan_humu_page/parameters.yml \
   --charset dan_humu_page/charset.pkl \
   --output dan_humu_page/predict/ \
-  --scale 0.5 \
   --confidence-score
 ```
 It will create the following JSON file named `dan_humu_page/predict/example.json`
@@ -60,7 +58,6 @@ teklia-dan predict \
   --parameters dan_humu_page/parameters.yml \
   --charset dan_humu_page/charset.pkl \
   --output dan_humu_page/predict/ \
-  --scale 0.5 \
   --confidence-score \
   --attention-map \
 ```
@@ -88,7 +85,6 @@ teklia-dan predict \
   --parameters dan_humu_page/parameters.yml \
   --charset dan_humu_page/charset.pkl \
   --output dan_humu_page/predict/ \
-  --scale 0.5 \
   --confidence-score \
   --attention-map \
   --attention-map-level word \
@@ -118,7 +114,6 @@ teklia-dan predict \
   --parameters dan_humu_page/parameters.yml \
   --charset dan_humu_page/charset.pkl \
   --output dan_humu_page/predict/ \
-  --scale 0.5 \
   --attention-map \
   --predict-objects \
   --threshold-method otsu
...
 ---
 version: 0.0.1
 parameters:
-  mean: [166.8418783515498, 166.8418783515498, 166.8418783515498]
-  std: [34.084189571536385, 34.084189571536385, 34.084189571536385]
   max_char_prediction: 200
   encoder:
     input_channels: 3
@@ -22,3 +20,7 @@ parameters:
   dec_num_heads: 4
   dec_att_dropout: 0.1
   dec_res_dropout: 0.1
+  preprocessings:
+    - type: "max_resize"
+      max_height: 1500
+      max_width: 1500
@@ -6,7 +6,6 @@ import pytest

 from dan.predict.prediction import DAN
 from dan.predict.prediction import run as run_prediction
-from dan.utils import read_image


 @pytest.mark.parametrize(
@@ -45,8 +44,8 @@ def test_predict(
         mode="eval",
     )

-    image = read_image(prediction_data_path / "images" / image_name)
-    image = dan_model.preprocess(image)
+    image_path = prediction_data_path / "images" / image_name
+    image = dan_model.preprocess(str(image_path))

     input_tensor = image.unsqueeze(0)
     input_tensor = input_tensor.to(device)
@@ -104,7 +103,7 @@ def test_predict(
                     {"text": "ⓁP", "confidence": 0.94},
                     {"text": "ⒸM", "confidence": 0.93},
                     {"text": "ⓀCh", "confidence": 0.96},
-                    {"text": "ⓄPlombier", "confidence": 0.94},
+                    {"text": "ⓄPlombier", "confidence": 0.93},
                     {"text": "ⓅPatron?12241", "confidence": 0.93},
                 ],
             },
@@ -258,7 +257,6 @@ def test_run_prediction(
         parameters=prediction_data_path / "parameters.yml",
         charset=prediction_data_path / "charset.pkl",
         output=tmp_path,
-        scale=1,
         confidence_score=True if confidence_score else False,
         confidence_score_levels=confidence_score if confidence_score else [],
         attention_map=False,
@@ -267,7 +265,6 @@ def test_run_prediction(
         word_separators=[" ", "\n"],
         line_separators=["\n"],
         temperature=temperature,
-        image_max_width=None,
         predict_objects=False,
         threshold_method="otsu",
         threshold_value=0,
...