diff --git a/dan/ocr/predict/__init__.py b/dan/ocr/predict/__init__.py index d535a0c541ef54639d0f5fa6d2d82749fee3f18c..1c65254a96a7a661d0b73a2a92ee55260e5aff82 100644 --- a/dan/ocr/predict/__init__.py +++ b/dan/ocr/predict/__init__.py @@ -79,7 +79,7 @@ def add_predict_parser(subcommands) -> None: ) parser.add_argument( "--confidence-score-levels", - default="", + default=[], type=str, nargs="+", help="Levels of confidence scores. Should be a list of any combinaison of ['char', 'word', 'line'].", @@ -132,6 +132,7 @@ def add_predict_parser(subcommands) -> None: "--threshold-method", help="Thresholding method.", choices=["otsu", "simple"], + type=str, default="otsu", ) parser.add_argument( diff --git a/dan/ocr/predict/attention.py b/dan/ocr/predict/attention.py index 6c569f56630542c9f0df09ca50c084e20eae7305..20b1adff8c98fef04539e5cada55125cd5e72ac6 100644 --- a/dan/ocr/predict/attention.py +++ b/dan/ocr/predict/attention.py @@ -1,19 +1,23 @@ # -*- coding: utf-8 -*- import re +from typing import List, Tuple import cv2 import numpy as np +import torch from PIL import Image from torchvision.transforms.functional import to_pil_image from dan import logger -def parse_delimiters(delimiters): +def parse_delimiters(delimiters: List[str]) -> re.Pattern: return re.compile(r"|".join(delimiters)) -def compute_prob_by_separator(characters, probabilities, separator): +def compute_prob_by_separator( + characters: str, probabilities: List[float], separator: re.Pattern +) -> Tuple[List[str], List[np.float64]]: """ Split text and confidences using separators and return a list of average confidence scores. :param characters: list of characters. @@ -31,7 +35,9 @@ def compute_prob_by_separator(characters, probabilities, separator): return texts, probs -def split_text(text: str, level: str, word_separators, line_separators): +def split_text( + text: str, level: str, word_separators: re.Pattern, line_separators: re.Pattern +) -> Tuple[List[str], int]: """ Split text into a list of characters, word, or lines. :param text: Text prediction from DAN @@ -57,8 +63,12 @@ def split_text(text: str, level: str, word_separators, line_separators): def split_text_and_confidences( - text, confidences, level, word_separators, line_separators -): + text: str, + confidences: List[float], + level: str, + word_separators: re.Pattern, + line_separators: re.Pattern, +) -> Tuple[List[str], List[np.float64], int]: """ Split text into a list of characters, words or lines with corresponding confidences scores :param text: Text prediction from DAN @@ -86,17 +96,17 @@ def split_text_and_confidences( def get_predicted_polygons_with_confidence( - text, - weights, - confidences, - level, - height, - width, - threshold_method="otsu", - threshold_value=0, - word_separators=["\n", " "], - line_separators=["\n"], -): + text: str, + weights: np.ndarray, + confidences: List[float], + level: str, + height: int, + width: int, + threshold_method: str = "otsu", + threshold_value: int = 0, + word_separators: re.Pattern = parse_delimiters(["\n", " "]), + line_separators: re.Pattern = parse_delimiters(["\n"]), +) -> List[dict]: """ Returns the polygons of each object of the current prediction :param text: Text predicted by DAN @@ -135,7 +145,13 @@ def get_predicted_polygons_with_confidence( return polygons -def compute_coverage(text: str, max_value: float, offset: int, attentions, size: tuple): +def compute_coverage( + text: str, + max_value: np.float32, + offset: int, + attentions: np.ndarray, + size: Tuple[int, int], +) -> np.ndarray: """ Aggregates attention maps for the current text piece (char, word, line) :param text: Text piece selected with offset after splitting DAN prediction @@ -162,7 +178,9 @@ def compute_coverage(text: str, max_value: float, offset: int, attentions, size: return coverage_vector -def blend_coverage(coverage_vector, image, mask, scale): +def blend_coverage( + coverage_vector: np.ndarray, image: Image.Image, mask: Image.Image, scale: float +) -> Image.Image: """ Blends current coverage_vector over original image, used to make an attention map. :param coverage_vector: Aggregated attention weights of the current text piece, resized to image. size: (n_char, image_height, image_width) @@ -184,7 +202,9 @@ def blend_coverage(coverage_vector, image, mask, scale): return blend -def compute_contour_metrics(coverage_vector, contour): +def compute_contour_metrics( + coverage_vector: np.ndarray, contour: np.ndarray +) -> Tuple[np.float64, np.float64]: """ Compute the contours's area and the mean value inside it. :param coverage_vector: Aggregated attention weights of the current text piece, resized to image. size: (n_char, image_height, image_width) @@ -199,12 +219,14 @@ def compute_contour_metrics(coverage_vector, contour): return max_value, max_value * area -def polygon_to_bbx(polygon): +def polygon_to_bbx(polygon: np.ndarray) -> List[Tuple[int, int]]: x, y, w, h = cv2.boundingRect(polygon) return [[x, y], [x + w, y], [x + w, y + h], [x, y + h]] -def threshold(mask, threshold_method="otsu", threshold_value=0): +def threshold( + mask: np.ndarray, threshold_method: str = "otsu", threshold_value: int = 0 +) -> np.ndarray: """ Threshold a grayscale mask. :param mask: a grayscale image (np.array) @@ -234,8 +256,14 @@ def threshold(mask, threshold_method="otsu", threshold_value=0): def get_polygon( - text, max_value, offset, weights, threshold_method, threshold_value, size=None -): + text: str, + max_value: np.float32, + offset: int, + weights: np.ndarray, + threshold_method: str, + threshold_value: int, + size: Tuple[int, int] = None, +) -> Tuple[dict, np.ndarray]: """ Gets polygon associated with element of current text_piece, indexed by offset :param text: Text piece selected with offset after splitting DAN prediction @@ -279,18 +307,18 @@ def get_polygon( def plot_attention( - image, - text, - weights, - level, - scale, - outname, - threshold_method="otsu", - threshold_value=0, - word_separators=["\n", " "], - line_separators=["\n"], - display_polygons=False, -): + image: torch.Tensor, + text: str, + weights: np.ndarray, + level: str, + scale: float, + outname: str, + threshold_method: str = "otsu", + threshold_value: int = 0, + word_separators: re.Pattern = parse_delimiters(["\n", " "]), + line_separators: re.Pattern = parse_delimiters(["\n"]), + display_polygons: bool = False, +) -> None: """ Create a gif by blending attention maps to the image for each text piece (char, word or line) :param image: Input image as torch.Tensor diff --git a/dan/ocr/predict/prediction.py b/dan/ocr/predict/prediction.py index faa3383906e07d385cafb7c2efd40b890eab1219..32758f8a29553bffc717c93c896eeb2ad6996010 100644 --- a/dan/ocr/predict/prediction.py +++ b/dan/ocr/predict/prediction.py @@ -2,8 +2,10 @@ import json import pickle +import re from itertools import pairwise from pathlib import Path +from typing import Iterable, List, Optional, Tuple import numpy as np import torch @@ -34,7 +36,7 @@ class DAN: The class initializes useful parameters: the device and the temperature scalar parameter. """ - def __init__(self, device, temperature=1.0): + def __init__(self, device: str, temperature=1.0) -> None: """ Constructor of the DAN class. :param device: The device to use. @@ -44,8 +46,12 @@ class DAN: self.temperature = temperature def load( - self, model_path: Path, params_path: Path, charset_path: Path, mode="eval" - ): + self, + model_path: Path, + params_path: Path, + charset_path: Path, + mode: str = "eval", + ) -> None: """ Load a trained model. :param model_path: Path to the model. @@ -88,7 +94,7 @@ class DAN: ) self.max_chars = parameters["max_char_prediction"] - def preprocess(self, path): + def preprocess(self, path: str) -> Tuple[torch.Tensor, torch.Tensor]: """ Preprocess an image. :param path: Path of the image to load and preprocess. @@ -104,18 +110,18 @@ class DAN: def predict( self, - input_tensor, - input_sizes, - confidences=False, - attentions=False, - attention_level=False, - extract_objects=False, - word_separators=["\n", " "], - line_separators=["\n"], - start_token=None, - threshold_method="otsu", - threshold_value=0, - ): + input_tensor: torch.Tensor, + input_sizes: List[torch.Size], + confidences: bool = False, + attentions: bool = False, + attention_level: str = "line", + extract_objects: bool = False, + word_separators: re.Pattern = parse_delimiters(["\n", " "]), + line_separators: re.Pattern = parse_delimiters(["\n"]), + start_token: str = None, + threshold_method: str = "otsu", + threshold_value: int = 0, + ) -> dict: """ Run prediction on an input image. :param input_tensor: A batch of images to predict. @@ -254,7 +260,9 @@ class DAN: return out -def parse_ner_predictions(text, char_confidences, predictions): +def parse_ner_predictions( + text: str, char_confidences: List[float], predictions: Iterable[Tuple[int, int]] +) -> List[dict]: return [ { "text": f"{text[current: next_token]}".replace("\n", " "), @@ -265,22 +273,22 @@ def parse_ner_predictions(text, char_confidences, predictions): def process_batch( - image_batch, - dan_model, - device, - output, - confidence_score, - confidence_score_levels, - attention_map, - attention_map_level, - attention_map_scale, - word_separators, - line_separators, - predict_objects, - threshold_method, - threshold_value, - tokens, -): + image_batch: List[Path], + dan_model: DAN, + device: str, + output: Path, + confidence_score: bool, + confidence_score_levels: List[str], + attention_map: bool, + attention_map_level: str, + attention_map_scale: float, + word_separators: List[str], + line_separators: List[str], + predict_objects: bool, + threshold_method: str, + threshold_value: int, + tokens: Path, +) -> None: input_images, visu_images, input_sizes = [], [], [] logger.info("Loading images...") for image_path in image_batch: @@ -394,28 +402,28 @@ def process_batch( def run( - image, - image_dir, - model, - parameters, - charset, - output, - confidence_score, - confidence_score_levels, - attention_map, - attention_map_level, - attention_map_scale, - word_separators, - line_separators, - temperature, - predict_objects, - threshold_method, - threshold_value, - image_extension, - gpu_device, - batch_size, - tokens, -): + image: Optional[Path], + image_dir: Optional[Path], + model: Path, + parameters: Path, + charset: Path, + output: Path, + confidence_score: bool, + confidence_score_levels: List[str], + attention_map: bool, + attention_map_level: str, + attention_map_scale: float, + word_separators: List[str], + line_separators: List[str], + temperature: float, + predict_objects: bool, + threshold_method: str, + threshold_value: int, + image_extension: str, + gpu_device: int, + batch_size: int, + tokens: Path, +) -> None: """ Predict a single image save the output :param image: Path to the image to predict.