Skip to content
Snippets Groups Projects
Commit 78281e2e authored by Manon Blanco's avatar Manon Blanco
Browse files

Type hints in dan.ocr.predict

parent 0bec214f
No related branches found
No related tags found
1 merge request: !256 "Type hints in dan.ocr.predict"
......@@ -79,7 +79,7 @@ def add_predict_parser(subcommands) -> None:
)
parser.add_argument(
"--confidence-score-levels",
default="",
default=[],
type=str,
nargs="+",
help="Levels of confidence scores. Should be a list of any combinaison of ['char', 'word', 'line'].",
......@@ -132,6 +132,7 @@ def add_predict_parser(subcommands) -> None:
"--threshold-method",
help="Thresholding method.",
choices=["otsu", "simple"],
type=str,
default="otsu",
)
parser.add_argument(
......
# -*- coding: utf-8 -*-
import re
from typing import List, Tuple
import cv2
import numpy as np
import torch
from PIL import Image
from torchvision.transforms.functional import to_pil_image
from dan import logger
def parse_delimiters(delimiters: List[str]) -> re.Pattern:
    """
    Compile a list of delimiters into a single regular expression.
    The delimiters are joined with ``|`` without being escaped, so each
    delimiter is interpreted as a regular expression, not as a literal string.
    :param delimiters: Delimiters that the resulting pattern should match.
    :return: Compiled pattern matching any one of the delimiters.
    """
    return re.compile(r"|".join(delimiters))
def compute_prob_by_separator(characters, probabilities, separator):
def compute_prob_by_separator(
characters: str, probabilities: List[float], separator: re.Pattern
) -> Tuple[List[str], List[np.float64]]:
"""
Split text and confidences using separators and return a list of average confidence scores.
:param characters: list of characters.
......@@ -31,7 +35,9 @@ def compute_prob_by_separator(characters, probabilities, separator):
return texts, probs
def split_text(text: str, level: str, word_separators, line_separators):
def split_text(
text: str, level: str, word_separators: re.Pattern, line_separators: re.Pattern
) -> Tuple[List[str], int]:
"""
Split text into a list of characters, word, or lines.
:param text: Text prediction from DAN
......@@ -57,8 +63,12 @@ def split_text(text: str, level: str, word_separators, line_separators):
def split_text_and_confidences(
text, confidences, level, word_separators, line_separators
):
text: str,
confidences: List[float],
level: str,
word_separators: re.Pattern,
line_separators: re.Pattern,
) -> Tuple[List[str], List[np.float64], int]:
"""
Split text into a list of characters, words or lines with corresponding confidence scores
:param text: Text prediction from DAN
......@@ -86,17 +96,17 @@ def split_text_and_confidences(
def get_predicted_polygons_with_confidence(
text,
weights,
confidences,
level,
height,
width,
threshold_method="otsu",
threshold_value=0,
word_separators=["\n", " "],
line_separators=["\n"],
):
text: str,
weights: np.ndarray,
confidences: List[float],
level: str,
height: int,
width: int,
threshold_method: str = "otsu",
threshold_value: int = 0,
word_separators: re.Pattern = parse_delimiters(["\n", " "]),
line_separators: re.Pattern = parse_delimiters(["\n"]),
) -> List[dict]:
"""
Returns the polygons of each object of the current prediction
:param text: Text predicted by DAN
......@@ -135,7 +145,13 @@ def get_predicted_polygons_with_confidence(
return polygons
def compute_coverage(text: str, max_value: float, offset: int, attentions, size: tuple):
def compute_coverage(
text: str,
max_value: np.float32,
offset: int,
attentions: np.ndarray,
size: Tuple[int, int],
) -> np.ndarray:
"""
Aggregates attention maps for the current text piece (char, word, line)
:param text: Text piece selected with offset after splitting DAN prediction
......@@ -162,7 +178,9 @@ def compute_coverage(text: str, max_value: float, offset: int, attentions, size:
return coverage_vector
def blend_coverage(coverage_vector, image, mask, scale):
def blend_coverage(
coverage_vector: np.ndarray, image: Image.Image, mask: Image.Image, scale: float
) -> Image.Image:
"""
Blends current coverage_vector over original image, used to make an attention map.
:param coverage_vector: Aggregated attention weights of the current text piece, resized to image. size: (n_char, image_height, image_width)
......@@ -184,7 +202,9 @@ def blend_coverage(coverage_vector, image, mask, scale):
return blend
def compute_contour_metrics(coverage_vector, contour):
def compute_contour_metrics(
coverage_vector: np.ndarray, contour: np.ndarray
) -> Tuple[np.float64, np.float64]:
"""
Compute the contour's area and the mean value inside it.
:param coverage_vector: Aggregated attention weights of the current text piece, resized to image. size: (n_char, image_height, image_width)
......@@ -199,12 +219,14 @@ def compute_contour_metrics(coverage_vector, contour):
return max_value, max_value * area
def polygon_to_bbx(polygon: np.ndarray) -> List[List[int]]:
    """
    Convert a polygon to the corners of its bounding rectangle.
    The rectangle (x, y, w, h) is computed by cv2.boundingRect.
    :param polygon: Points of the polygon, passed unchanged to cv2.boundingRect.
    :return: The four corners as [x, y] pairs, in the order
        [x, y], [x + w, y], [x + w, y + h], [x, y + h].
    """
    x, y, w, h = cv2.boundingRect(polygon)
    # Corners are returned as two-element lists (not tuples).
    return [[x, y], [x + w, y], [x + w, y + h], [x, y + h]]
def threshold(mask, threshold_method="otsu", threshold_value=0):
def threshold(
mask: np.ndarray, threshold_method: str = "otsu", threshold_value: int = 0
) -> np.ndarray:
"""
Threshold a grayscale mask.
:param mask: a grayscale image (np.array)
......@@ -234,8 +256,14 @@ def threshold(mask, threshold_method="otsu", threshold_value=0):
def get_polygon(
text, max_value, offset, weights, threshold_method, threshold_value, size=None
):
text: str,
max_value: np.float32,
offset: int,
weights: np.ndarray,
threshold_method: str,
threshold_value: int,
size: Tuple[int, int] = None,
) -> Tuple[dict, np.ndarray]:
"""
Gets polygon associated with element of current text_piece, indexed by offset
:param text: Text piece selected with offset after splitting DAN prediction
......@@ -279,18 +307,18 @@ def get_polygon(
def plot_attention(
image,
text,
weights,
level,
scale,
outname,
threshold_method="otsu",
threshold_value=0,
word_separators=["\n", " "],
line_separators=["\n"],
display_polygons=False,
):
image: torch.Tensor,
text: str,
weights: np.ndarray,
level: str,
scale: float,
outname: str,
threshold_method: str = "otsu",
threshold_value: int = 0,
word_separators: re.Pattern = parse_delimiters(["\n", " "]),
line_separators: re.Pattern = parse_delimiters(["\n"]),
display_polygons: bool = False,
) -> None:
"""
Create a gif by blending attention maps to the image for each text piece (char, word or line)
:param image: Input image as torch.Tensor
......
......@@ -2,8 +2,10 @@
import json
import pickle
import re
from itertools import pairwise
from pathlib import Path
from typing import Iterable, List, Optional, Tuple
import numpy as np
import torch
......@@ -34,7 +36,7 @@ class DAN:
The class initializes useful parameters: the device and the temperature scalar parameter.
"""
def __init__(self, device, temperature=1.0):
def __init__(self, device: str, temperature=1.0) -> None:
"""
Constructor of the DAN class.
:param device: The device to use.
......@@ -44,8 +46,12 @@ class DAN:
self.temperature = temperature
def load(
self, model_path: Path, params_path: Path, charset_path: Path, mode="eval"
):
self,
model_path: Path,
params_path: Path,
charset_path: Path,
mode: str = "eval",
) -> None:
"""
Load a trained model.
:param model_path: Path to the model.
......@@ -88,7 +94,7 @@ class DAN:
)
self.max_chars = parameters["max_char_prediction"]
def preprocess(self, path):
def preprocess(self, path: str) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Preprocess an image.
:param path: Path of the image to load and preprocess.
......@@ -104,18 +110,18 @@ class DAN:
def predict(
self,
input_tensor,
input_sizes,
confidences=False,
attentions=False,
attention_level=False,
extract_objects=False,
word_separators=["\n", " "],
line_separators=["\n"],
start_token=None,
threshold_method="otsu",
threshold_value=0,
):
input_tensor: torch.Tensor,
input_sizes: List[torch.Size],
confidences: bool = False,
attentions: bool = False,
attention_level: str = "line",
extract_objects: bool = False,
word_separators: re.Pattern = parse_delimiters(["\n", " "]),
line_separators: re.Pattern = parse_delimiters(["\n"]),
start_token: str = None,
threshold_method: str = "otsu",
threshold_value: int = 0,
) -> dict:
"""
Run prediction on an input image.
:param input_tensor: A batch of images to predict.
......@@ -254,7 +260,9 @@ class DAN:
return out
def parse_ner_predictions(text, char_confidences, predictions):
def parse_ner_predictions(
text: str, char_confidences: List[float], predictions: Iterable[Tuple[int, int]]
) -> List[dict]:
return [
{
"text": f"{text[current: next_token]}".replace("\n", " "),
......@@ -265,22 +273,22 @@ def parse_ner_predictions(text, char_confidences, predictions):
def process_batch(
image_batch,
dan_model,
device,
output,
confidence_score,
confidence_score_levels,
attention_map,
attention_map_level,
attention_map_scale,
word_separators,
line_separators,
predict_objects,
threshold_method,
threshold_value,
tokens,
):
image_batch: List[Path],
dan_model: DAN,
device: str,
output: Path,
confidence_score: bool,
confidence_score_levels: List[str],
attention_map: bool,
attention_map_level: str,
attention_map_scale: float,
word_separators: List[str],
line_separators: List[str],
predict_objects: bool,
threshold_method: str,
threshold_value: int,
tokens: Path,
) -> None:
input_images, visu_images, input_sizes = [], [], []
logger.info("Loading images...")
for image_path in image_batch:
......@@ -394,28 +402,28 @@ def process_batch(
def run(
image,
image_dir,
model,
parameters,
charset,
output,
confidence_score,
confidence_score_levels,
attention_map,
attention_map_level,
attention_map_scale,
word_separators,
line_separators,
temperature,
predict_objects,
threshold_method,
threshold_value,
image_extension,
gpu_device,
batch_size,
tokens,
):
image: Optional[Path],
image_dir: Optional[Path],
model: Path,
parameters: Path,
charset: Path,
output: Path,
confidence_score: bool,
confidence_score_levels: List[str],
attention_map: bool,
attention_map_level: str,
attention_map_scale: float,
word_separators: List[str],
line_separators: List[str],
temperature: float,
predict_objects: bool,
threshold_method: str,
threshold_value: int,
image_extension: str,
gpu_device: int,
batch_size: int,
tokens: Path,
) -> None:
"""
Predict a single image and save the output
:param image: Path to the image to predict.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment