Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • atr/dan
1 result
Show changes
Commits on Source (2)
......@@ -3,12 +3,10 @@ import json
import os
import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision.io import ImageReadMode, read_image
from dan.datasets.utils import natural_sort
from dan.utils import token_to_ind
from dan.utils import read_image, token_to_ind
class OCRDataset(Dataset):
......@@ -82,14 +80,6 @@ class OCRDataset(Dataset):
)
return sample
@staticmethod
def load_image(path):
"""
Load an image as a torch.Tensor and scale the values between 0 and 1.
"""
img = read_image(path, mode=ImageReadMode.RGB)
return img.to(dtype=torch.get_default_dtype()).div(255)
def load_samples(self, paths_and_sets):
"""
Load images and labels
......@@ -116,7 +106,7 @@ class OCRDataset(Dataset):
)
if self.load_in_memory:
samples[-1]["img"] = self.preprocessing_transforms(
self.load_image(filename)
read_image(filename)
)
return samples
......@@ -126,10 +116,8 @@ class OCRDataset(Dataset):
"""
if self.load_in_memory:
return self.samples[i]["img"]
else:
return self.preprocessing_transforms(
self.load_image(self.samples[i]["path"])
)
return self.preprocessing_transforms(read_image(self.samples[i]["path"]))
def compute_final_size(self, img):
"""
......
......@@ -49,14 +49,14 @@ class OCRDatasetManager:
self.params["config"]["padding_token"] = self.tokens["pad"]
self.my_collate_function = OCRCollateFunction(self.params["config"])
self.normalization = get_normalization_transforms()
self.normalization = get_normalization_transforms(from_pil_image=True)
self.augmentation = (
get_augmentation_transforms()
if self.params["config"]["augmentation"]
else None
)
self.preprocessing = get_preprocessing_transforms(
params["config"]["preprocessings"]
params["config"]["preprocessings"], to_pil_image=True
)
def load_datasets(self):
......
......@@ -58,20 +58,6 @@ def add_predict_parser(subcommands) -> None:
help="The extension of the images in the folder.",
default=".jpg",
)
parser.add_argument(
"--scale",
type=float,
default=1.0,
required=False,
help="Image scaling factor before feeding it to DAN",
)
parser.add_argument(
"--image-max-width",
type=int,
default=None,
required=False,
help="Image resizing before feeding it to DAN",
)
parser.add_argument(
"--temperature",
type=float,
......
......@@ -5,7 +5,6 @@ import pickle
from itertools import pairwise
from pathlib import Path
import cv2
import numpy as np
import torch
import yaml
......@@ -20,7 +19,7 @@ from dan.predict.attention import (
plot_attention,
split_text_and_confidences,
)
from dan.transforms import get_normalization_transforms
from dan.transforms import get_normalization_transforms, get_preprocessing_transforms
from dan.utils import ind_to_token, read_image
......@@ -76,22 +75,19 @@ class DAN:
self.encoder = encoder
self.decoder = decoder
self.normalization = get_normalization_transforms()
self.preprocessing_transforms = get_preprocessing_transforms(
parameters.get("preprocessings", [])
)
self.max_chars = parameters["max_char_prediction"]
def preprocess(self, input_image):
def preprocess(self, path):
"""
Preprocess an input_image.
:param input_image: The input image to preprocess.
Preprocess an image.
:param path: Path of the image to load and preprocess.
"""
assert isinstance(
input_image, np.ndarray
), "Input image must be an np.array in RGB"
input_image = np.asarray(input_image)
if len(input_image.shape) < 3:
input_image = cv2.cvtColor(input_image, cv2.COLOR_GRAY2RGB)
input_image = self.normalization(input_image)
return input_image
image = read_image(path)
preprocessed_image = self.preprocessing_transforms(image)
return self.normalization(preprocessed_image)
def predict(
self,
......@@ -253,11 +249,10 @@ class DAN:
def process_image(
image,
image_path,
dan_model,
device,
output,
scale,
confidence_score,
confidence_score_levels,
attention_map,
......@@ -265,27 +260,18 @@ def process_image(
attention_map_scale,
word_separators,
line_separators,
image_max_width,
predict_objects,
threshold_method,
threshold_value,
):
# Load image and pre-process it
if image_max_width:
_, w, _ = read_image(image, scale=1).shape
ratio = image_max_width / w
im = read_image(image, ratio)
else:
im = read_image(image, scale=scale)
image = dan_model.preprocess(str(image_path))
logger.info("Image loaded.")
im_p = dan_model.preprocess(im)
logger.debug("Image pre-processed.")
# Convert to tensor of size (batch_size, channel, height, width) with batch_size=1
input_tensor = im_p.unsqueeze(0)
input_tensor = image.unsqueeze(0)
input_tensor = input_tensor.to(device)
input_sizes = [im_p.shape[1:]]
input_sizes = [image.shape[1:]]
# Parse delimiters to regex
word_separators = parse_delimiters(word_separators)
......@@ -347,11 +333,11 @@ def process_image(
# Save gif with attention map
if attention_map:
gif_filename = f"{output}/{image.stem}_{attention_map_level}.gif"
gif_filename = f"{output}/{image_path.stem}_{attention_map_level}.gif"
logger.info(f"Creating attention GIF in {gif_filename}")
# this returns polygons but unused for now.
plot_attention(
image=im,
image=image,
text=prediction["text"][0],
weights=prediction["attentions"][0],
level=attention_map_level,
......@@ -365,7 +351,7 @@ def process_image(
)
result["attention_gif"] = gif_filename
json_filename = f"{output}/{image.stem}.json"
json_filename = f"{output}/{image_path.stem}.json"
logger.info(f"Saving JSON prediction in {json_filename}")
save_json(Path(json_filename), result)
......@@ -377,7 +363,6 @@ def run(
parameters,
charset,
output,
scale,
confidence_score,
confidence_score_levels,
attention_map,
......@@ -386,7 +371,6 @@ def run(
word_separators,
line_separators,
temperature,
image_max_width,
predict_objects,
threshold_method,
threshold_value,
......@@ -401,14 +385,12 @@ def run(
:param parameters: Path to the YAML parameters file.
:param charset: Path to the charset.
:param output: Path to the output folder where the results will be saved.
:param scale: Scaling factor to resize the image.
:param confidence_score: Whether to compute confidence score.
:param attention_map: Whether to plot the attention map.
:param attention_map_level: Level of objects to extract.
:param attention_map_scale: Scaling factor for the attention map.
:param word_separators: List of word separators.
:param line_separators: List of line separators.
:param image_max_width: Resize image
:param predict_objects: Whether to extract objects.
:param threshold_method: Thresholding method. Should be in ["otsu", "simple"].
:param threshold_value: Thresholding value to use for the "simple" thresholding method.
......@@ -423,13 +405,14 @@ def run(
device = f"cuda{cuda_device}" if torch.cuda.is_available() else "cpu"
dan_model = DAN(device, temperature)
dan_model.load(model, parameters, charset, mode="eval")
if image:
images = image_dir.rglob(f"*{image_extension}") if not image else [image]
for image_name in images:
process_image(
image,
image_name,
dan_model,
device,
output,
scale,
confidence_score,
confidence_score_levels,
attention_map,
......@@ -437,28 +420,7 @@ def run(
attention_map_scale,
word_separators,
line_separators,
image_max_width,
predict_objects,
threshold_method,
threshold_value,
)
else:
for image_name in image_dir.rglob(f"*{image_extension}"):
process_image(
image_name,
dan_model,
device,
output,
scale,
confidence_score,
confidence_score_levels,
attention_map,
attention_map_level,
attention_map_scale,
word_separators,
line_separators,
image_max_width,
predict_objects,
threshold_method,
threshold_value,
)
......@@ -145,7 +145,9 @@ class ErosionDilation:
return {"image": augmented_image}
def get_preprocessing_transforms(preprocessings: list) -> Compose:
def get_preprocessing_transforms(
preprocessings: list, to_pil_image: bool = False
) -> Compose:
"""
Returns a list of transformations to be applied to the image.
"""
......@@ -165,7 +167,10 @@ def get_preprocessing_transforms(preprocessings: list) -> Compose:
)
case Preprocessing.FixedWidthResize:
transforms.append(FixedWidthResize(width=preprocessing["fixed_width"]))
transforms.append(ToPILImage())
if to_pil_image:
transforms.append(ToPILImage())
return Compose(transforms)
......@@ -192,8 +197,14 @@ def get_augmentation_transforms() -> SomeOf:
)
def get_normalization_transforms() -> Compose:
def get_normalization_transforms(from_pil_image: bool = False) -> Compose:
"""
Returns a list of normalization transformations.
"""
return Compose([ToTensor(), Normalize(IMAGENET_MEAN, IMAGENET_STD)])
transforms = []
if from_pil_image:
transforms.append(ToTensor())
transforms.append(Normalize(IMAGENET_MEAN, IMAGENET_STD))
return Compose(transforms)
# -*- coding: utf-8 -*-
import cv2
import torch
import torchvision.io as torchvision
# Layout begin-token to end-token
SEM_MATCHING_TOKENS = {"": "", "": "", "": "", "": "", "": "", "": ""}
......@@ -47,18 +47,13 @@ def pad_images(data):
return padded_data
def read_image(filename, scale=1.0):
def read_image(path):
"""
Read image and rescale it
:param filename: Image path
:param scale: Scaling factor before prediction
Read image with torch
:param path: Path of the image to load.
"""
image = cv2.cvtColor(cv2.imread(str(filename)), cv2.COLOR_BGR2RGB)
if scale != 1.0:
width = int(image.shape[1] * scale)
height = int(image.shape[0] * scale)
image = cv2.resize(image, (width, height), interpolation=cv2.INTER_AREA)
return image
img = torchvision.read_image(path, mode=torchvision.ImageReadMode.RGB)
return img.to(dtype=torch.get_default_dtype()).div(255)
# Charset / labels conversion
......
black==23.3.0
doc8==1.1.1
# Pick a specific version because griffe==0.32.0 introduces a bug
griffe==0.31.0
mkdocs==1.4.2
mkdocs-material==9.1.9
mkdocstrings==0.20.0
......
......@@ -68,5 +68,11 @@ parameters:
dec_num_heads: int
dec_att_dropout: float
dec_res_dropout: float
preprocessings:
- type: str
max_height: int
max_width: int
fixed_height: int
fixed_width: int
```
2. Apply a trained DAN model on an image using the [predict command](../usage/predict.md).
......@@ -13,7 +13,6 @@ Use the `teklia-dan predict` command to apply a trained DAN model on an image.
| `--parameters` | Path to the YAML parameters file. | `Path` | |
| `--charset` | Path to the charset file. | `Path` | |
| `--output` | Path to the output folder. Results will be saved in this directory. | `Path` | |
| `--scale` | Image scaling factor before feeding it to DAN. | `float` | `1.0` |
| `--confidence-score` | Whether to return confidence scores. | `bool` | `False` |
| `--confidence-score-levels` | Levels at which to return confidence scores. Should be any combination of `["line", "word", "char"]`. | `str` | |
| `--attention-map` | Whether to plot attention maps. | `bool` | `False` |
......@@ -37,7 +36,6 @@ teklia-dan predict \
--parameters dan_humu_page/parameters.yml \
--charset dan_humu_page/charset.pkl \
--output dan_humu_page/predict/ \
--scale 0.5 \
--confidence-score
```
It will create the following JSON file, named `dan_humu_page/predict/example.json`:
......@@ -60,7 +58,6 @@ teklia-dan predict \
--parameters dan_humu_page/parameters.yml \
--charset dan_humu_page/charset.pkl \
--output dan_humu_page/predict/ \
--scale 0.5 \
--confidence-score \
--attention-map \
```
......@@ -88,7 +85,6 @@ teklia-dan predict \
--parameters dan_humu_page/parameters.yml \
--charset dan_humu_page/charset.pkl \
--output dan_humu_page/predict/ \
--scale 0.5 \
--confidence-score \
--attention-map \
--attention-map-level word \
......@@ -118,7 +114,6 @@ teklia-dan predict \
--parameters dan_humu_page/parameters.yml \
--charset dan_humu_page/charset.pkl \
--output dan_humu_page/predict/ \
--scale 0.5 \
--attention-map \
--predict-objects \
--threshold-method otsu
......
---
version: 0.0.1
parameters:
mean: [166.8418783515498, 166.8418783515498, 166.8418783515498]
std: [34.084189571536385, 34.084189571536385, 34.084189571536385]
max_char_prediction: 200
encoder:
input_channels: 3
......@@ -22,3 +20,7 @@ parameters:
dec_num_heads: 4
dec_att_dropout: 0.1
dec_res_dropout: 0.1
preprocessings:
- type: "max_resize"
max_height: 1500
max_width: 1500
......@@ -6,7 +6,6 @@ import pytest
from dan.predict.prediction import DAN
from dan.predict.prediction import run as run_prediction
from dan.utils import read_image
@pytest.mark.parametrize(
......@@ -45,8 +44,8 @@ def test_predict(
mode="eval",
)
image = read_image(prediction_data_path / "images" / image_name)
image = dan_model.preprocess(image)
image_path = prediction_data_path / "images" / image_name
image = dan_model.preprocess(str(image_path))
input_tensor = image.unsqueeze(0)
input_tensor = input_tensor.to(device)
......@@ -258,7 +257,6 @@ def test_run_prediction(
parameters=prediction_data_path / "parameters.yml",
charset=prediction_data_path / "charset.pkl",
output=tmp_path,
scale=1,
confidence_score=True if confidence_score else False,
confidence_score_levels=confidence_score if confidence_score else [],
attention_map=False,
......@@ -267,7 +265,6 @@ def test_run_prediction(
word_separators=[" ", "\n"],
line_separators=["\n"],
temperature=temperature,
image_max_width=None,
predict_objects=False,
threshold_method="otsu",
threshold_value=0,
......