Skip to content
Snippets Groups Projects
Commit 78281e2e authored by Manon Blanco
Browse files

Type hints in dan.ocr.predict

parent 0bec214f
No related branches found
No related tags found
1 merge request: !256 "Type hints in dan.ocr.predict"
...@@ -79,7 +79,7 @@ def add_predict_parser(subcommands) -> None: ...@@ -79,7 +79,7 @@ def add_predict_parser(subcommands) -> None:
) )
parser.add_argument( parser.add_argument(
"--confidence-score-levels", "--confidence-score-levels",
default="", default=[],
type=str, type=str,
nargs="+", nargs="+",
help="Levels of confidence scores. Should be a list of any combinaison of ['char', 'word', 'line'].", help="Levels of confidence scores. Should be a list of any combinaison of ['char', 'word', 'line'].",
...@@ -132,6 +132,7 @@ def add_predict_parser(subcommands) -> None: ...@@ -132,6 +132,7 @@ def add_predict_parser(subcommands) -> None:
"--threshold-method", "--threshold-method",
help="Thresholding method.", help="Thresholding method.",
choices=["otsu", "simple"], choices=["otsu", "simple"],
type=str,
default="otsu", default="otsu",
) )
parser.add_argument( parser.add_argument(
......
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import re import re
from typing import List, Tuple
import cv2 import cv2
import numpy as np import numpy as np
import torch
from PIL import Image from PIL import Image
from torchvision.transforms.functional import to_pil_image from torchvision.transforms.functional import to_pil_image
from dan import logger from dan import logger
def parse_delimiters(delimiters): def parse_delimiters(delimiters: List[str]) -> re.Pattern:
return re.compile(r"|".join(delimiters)) return re.compile(r"|".join(delimiters))
def compute_prob_by_separator(characters, probabilities, separator): def compute_prob_by_separator(
characters: str, probabilities: List[float], separator: re.Pattern
) -> Tuple[List[str], List[np.float64]]:
""" """
Split text and confidences using separators and return a list of average confidence scores. Split text and confidences using separators and return a list of average confidence scores.
:param characters: list of characters. :param characters: list of characters.
...@@ -31,7 +35,9 @@ def compute_prob_by_separator(characters, probabilities, separator): ...@@ -31,7 +35,9 @@ def compute_prob_by_separator(characters, probabilities, separator):
return texts, probs return texts, probs
def split_text(text: str, level: str, word_separators, line_separators): def split_text(
text: str, level: str, word_separators: re.Pattern, line_separators: re.Pattern
) -> Tuple[List[str], int]:
""" """
Split text into a list of characters, word, or lines. Split text into a list of characters, word, or lines.
:param text: Text prediction from DAN :param text: Text prediction from DAN
...@@ -57,8 +63,12 @@ def split_text(text: str, level: str, word_separators, line_separators): ...@@ -57,8 +63,12 @@ def split_text(text: str, level: str, word_separators, line_separators):
def split_text_and_confidences( def split_text_and_confidences(
text, confidences, level, word_separators, line_separators text: str,
): confidences: List[float],
level: str,
word_separators: re.Pattern,
line_separators: re.Pattern,
) -> Tuple[List[str], List[np.float64], int]:
""" """
Split text into a list of characters, words or lines with corresponding confidences scores Split text into a list of characters, words or lines with corresponding confidences scores
:param text: Text prediction from DAN :param text: Text prediction from DAN
...@@ -86,17 +96,17 @@ def split_text_and_confidences( ...@@ -86,17 +96,17 @@ def split_text_and_confidences(
def get_predicted_polygons_with_confidence( def get_predicted_polygons_with_confidence(
text, text: str,
weights, weights: np.ndarray,
confidences, confidences: List[float],
level, level: str,
height, height: int,
width, width: int,
threshold_method="otsu", threshold_method: str = "otsu",
threshold_value=0, threshold_value: int = 0,
word_separators=["\n", " "], word_separators: re.Pattern = parse_delimiters(["\n", " "]),
line_separators=["\n"], line_separators: re.Pattern = parse_delimiters(["\n"]),
): ) -> List[dict]:
""" """
Returns the polygons of each object of the current prediction Returns the polygons of each object of the current prediction
:param text: Text predicted by DAN :param text: Text predicted by DAN
...@@ -135,7 +145,13 @@ def get_predicted_polygons_with_confidence( ...@@ -135,7 +145,13 @@ def get_predicted_polygons_with_confidence(
return polygons return polygons
def compute_coverage(text: str, max_value: float, offset: int, attentions, size: tuple): def compute_coverage(
text: str,
max_value: np.float32,
offset: int,
attentions: np.ndarray,
size: Tuple[int, int],
) -> np.ndarray:
""" """
Aggregates attention maps for the current text piece (char, word, line) Aggregates attention maps for the current text piece (char, word, line)
:param text: Text piece selected with offset after splitting DAN prediction :param text: Text piece selected with offset after splitting DAN prediction
...@@ -162,7 +178,9 @@ def compute_coverage(text: str, max_value: float, offset: int, attentions, size: ...@@ -162,7 +178,9 @@ def compute_coverage(text: str, max_value: float, offset: int, attentions, size:
return coverage_vector return coverage_vector
def blend_coverage(coverage_vector, image, mask, scale): def blend_coverage(
coverage_vector: np.ndarray, image: Image.Image, mask: Image.Image, scale: float
) -> Image.Image:
""" """
Blends current coverage_vector over original image, used to make an attention map. Blends current coverage_vector over original image, used to make an attention map.
:param coverage_vector: Aggregated attention weights of the current text piece, resized to image. size: (n_char, image_height, image_width) :param coverage_vector: Aggregated attention weights of the current text piece, resized to image. size: (n_char, image_height, image_width)
...@@ -184,7 +202,9 @@ def blend_coverage(coverage_vector, image, mask, scale): ...@@ -184,7 +202,9 @@ def blend_coverage(coverage_vector, image, mask, scale):
return blend return blend
def compute_contour_metrics(coverage_vector, contour): def compute_contour_metrics(
coverage_vector: np.ndarray, contour: np.ndarray
) -> Tuple[np.float64, np.float64]:
""" """
Compute the contours's area and the mean value inside it. Compute the contours's area and the mean value inside it.
:param coverage_vector: Aggregated attention weights of the current text piece, resized to image. size: (n_char, image_height, image_width) :param coverage_vector: Aggregated attention weights of the current text piece, resized to image. size: (n_char, image_height, image_width)
...@@ -199,12 +219,14 @@ def compute_contour_metrics(coverage_vector, contour): ...@@ -199,12 +219,14 @@ def compute_contour_metrics(coverage_vector, contour):
return max_value, max_value * area return max_value, max_value * area
def polygon_to_bbx(polygon): def polygon_to_bbx(polygon: np.ndarray) -> List[Tuple[int, int]]:
x, y, w, h = cv2.boundingRect(polygon) x, y, w, h = cv2.boundingRect(polygon)
return [[x, y], [x + w, y], [x + w, y + h], [x, y + h]] return [[x, y], [x + w, y], [x + w, y + h], [x, y + h]]
def threshold(mask, threshold_method="otsu", threshold_value=0): def threshold(
mask: np.ndarray, threshold_method: str = "otsu", threshold_value: int = 0
) -> np.ndarray:
""" """
Threshold a grayscale mask. Threshold a grayscale mask.
:param mask: a grayscale image (np.array) :param mask: a grayscale image (np.array)
...@@ -234,8 +256,14 @@ def threshold(mask, threshold_method="otsu", threshold_value=0): ...@@ -234,8 +256,14 @@ def threshold(mask, threshold_method="otsu", threshold_value=0):
def get_polygon( def get_polygon(
text, max_value, offset, weights, threshold_method, threshold_value, size=None text: str,
): max_value: np.float32,
offset: int,
weights: np.ndarray,
threshold_method: str,
threshold_value: int,
size: Tuple[int, int] = None,
) -> Tuple[dict, np.ndarray]:
""" """
Gets polygon associated with element of current text_piece, indexed by offset Gets polygon associated with element of current text_piece, indexed by offset
:param text: Text piece selected with offset after splitting DAN prediction :param text: Text piece selected with offset after splitting DAN prediction
...@@ -279,18 +307,18 @@ def get_polygon( ...@@ -279,18 +307,18 @@ def get_polygon(
def plot_attention( def plot_attention(
image, image: torch.Tensor,
text, text: str,
weights, weights: np.ndarray,
level, level: str,
scale, scale: float,
outname, outname: str,
threshold_method="otsu", threshold_method: str = "otsu",
threshold_value=0, threshold_value: int = 0,
word_separators=["\n", " "], word_separators: re.Pattern = parse_delimiters(["\n", " "]),
line_separators=["\n"], line_separators: re.Pattern = parse_delimiters(["\n"]),
display_polygons=False, display_polygons: bool = False,
): ) -> None:
""" """
Create a gif by blending attention maps to the image for each text piece (char, word or line) Create a gif by blending attention maps to the image for each text piece (char, word or line)
:param image: Input image as torch.Tensor :param image: Input image as torch.Tensor
......
...@@ -2,8 +2,10 @@ ...@@ -2,8 +2,10 @@
import json import json
import pickle import pickle
import re
from itertools import pairwise from itertools import pairwise
from pathlib import Path from pathlib import Path
from typing import Iterable, List, Optional, Tuple
import numpy as np import numpy as np
import torch import torch
...@@ -34,7 +36,7 @@ class DAN: ...@@ -34,7 +36,7 @@ class DAN:
The class initializes useful parameters: the device and the temperature scalar parameter. The class initializes useful parameters: the device and the temperature scalar parameter.
""" """
def __init__(self, device, temperature=1.0): def __init__(self, device: str, temperature=1.0) -> None:
""" """
Constructor of the DAN class. Constructor of the DAN class.
:param device: The device to use. :param device: The device to use.
...@@ -44,8 +46,12 @@ class DAN: ...@@ -44,8 +46,12 @@ class DAN:
self.temperature = temperature self.temperature = temperature
def load( def load(
self, model_path: Path, params_path: Path, charset_path: Path, mode="eval" self,
): model_path: Path,
params_path: Path,
charset_path: Path,
mode: str = "eval",
) -> None:
""" """
Load a trained model. Load a trained model.
:param model_path: Path to the model. :param model_path: Path to the model.
...@@ -88,7 +94,7 @@ class DAN: ...@@ -88,7 +94,7 @@ class DAN:
) )
self.max_chars = parameters["max_char_prediction"] self.max_chars = parameters["max_char_prediction"]
def preprocess(self, path): def preprocess(self, path: str) -> Tuple[torch.Tensor, torch.Tensor]:
""" """
Preprocess an image. Preprocess an image.
:param path: Path of the image to load and preprocess. :param path: Path of the image to load and preprocess.
...@@ -104,18 +110,18 @@ class DAN: ...@@ -104,18 +110,18 @@ class DAN:
def predict( def predict(
self, self,
input_tensor, input_tensor: torch.Tensor,
input_sizes, input_sizes: List[torch.Size],
confidences=False, confidences: bool = False,
attentions=False, attentions: bool = False,
attention_level=False, attention_level: str = "line",
extract_objects=False, extract_objects: bool = False,
word_separators=["\n", " "], word_separators: re.Pattern = parse_delimiters(["\n", " "]),
line_separators=["\n"], line_separators: re.Pattern = parse_delimiters(["\n"]),
start_token=None, start_token: str = None,
threshold_method="otsu", threshold_method: str = "otsu",
threshold_value=0, threshold_value: int = 0,
): ) -> dict:
""" """
Run prediction on an input image. Run prediction on an input image.
:param input_tensor: A batch of images to predict. :param input_tensor: A batch of images to predict.
...@@ -254,7 +260,9 @@ class DAN: ...@@ -254,7 +260,9 @@ class DAN:
return out return out
def parse_ner_predictions(text, char_confidences, predictions): def parse_ner_predictions(
text: str, char_confidences: List[float], predictions: Iterable[Tuple[int, int]]
) -> List[dict]:
return [ return [
{ {
"text": f"{text[current: next_token]}".replace("\n", " "), "text": f"{text[current: next_token]}".replace("\n", " "),
...@@ -265,22 +273,22 @@ def parse_ner_predictions(text, char_confidences, predictions): ...@@ -265,22 +273,22 @@ def parse_ner_predictions(text, char_confidences, predictions):
def process_batch( def process_batch(
image_batch, image_batch: List[Path],
dan_model, dan_model: DAN,
device, device: str,
output, output: Path,
confidence_score, confidence_score: bool,
confidence_score_levels, confidence_score_levels: List[str],
attention_map, attention_map: bool,
attention_map_level, attention_map_level: str,
attention_map_scale, attention_map_scale: float,
word_separators, word_separators: List[str],
line_separators, line_separators: List[str],
predict_objects, predict_objects: bool,
threshold_method, threshold_method: str,
threshold_value, threshold_value: int,
tokens, tokens: Path,
): ) -> None:
input_images, visu_images, input_sizes = [], [], [] input_images, visu_images, input_sizes = [], [], []
logger.info("Loading images...") logger.info("Loading images...")
for image_path in image_batch: for image_path in image_batch:
...@@ -394,28 +402,28 @@ def process_batch( ...@@ -394,28 +402,28 @@ def process_batch(
def run( def run(
image, image: Optional[Path],
image_dir, image_dir: Optional[Path],
model, model: Path,
parameters, parameters: Path,
charset, charset: Path,
output, output: Path,
confidence_score, confidence_score: bool,
confidence_score_levels, confidence_score_levels: List[str],
attention_map, attention_map: bool,
attention_map_level, attention_map_level: str,
attention_map_scale, attention_map_scale: float,
word_separators, word_separators: List[str],
line_separators, line_separators: List[str],
temperature, temperature: float,
predict_objects, predict_objects: bool,
threshold_method, threshold_method: str,
threshold_value, threshold_value: int,
image_extension, image_extension: str,
gpu_device, gpu_device: int,
batch_size, batch_size: int,
tokens, tokens: Path,
): ) -> None:
""" """
Predict a single image save the output Predict a single image save the output
:param image: Path to the image to predict. :param image: Path to the image to predict.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment