From 3626aa70ceae0e4e22ea4ad16162f2c93f00654e Mon Sep 17 00:00:00 2001 From: M Generali <mgenerali@teklia.com> Date: Fri, 9 Jun 2023 10:34:28 +0000 Subject: [PATCH] Implement temperature scaling on dan --- dan/predict/__init__.py | 14 +++++++++++++ dan/predict/prediction.py | 44 ++++++++++++++++++++++++++++++++------- dan/utils.py | 12 +++++++++++ 3 files changed, 63 insertions(+), 7 deletions(-) diff --git a/dan/predict/__init__.py b/dan/predict/__init__.py index 0c17ffd2..44537752 100644 --- a/dan/predict/__init__.py +++ b/dan/predict/__init__.py @@ -55,6 +55,20 @@ def add_predict_parser(subcommands) -> None: required=False, help="Image scaling factor before feeding it to DAN", ) + parser.add_argument( + "--image-max-width", + type=int, + default=1800, + required=False, + help="Image resizing before feeding it to DAN", + ) + parser.add_argument( + "--temperature", + type=float, + default=1.0, + help="Temperature scaling scalar parameter", + required=True, + ) parser.add_argument( "--confidence-score", action="store_true", diff --git a/dan/predict/prediction.py b/dan/predict/prediction.py index 0de25443..ce25c133 100644 --- a/dan/predict/prediction.py +++ b/dan/predict/prediction.py @@ -20,22 +20,23 @@ from dan.predict.attention import ( plot_attention, split_text_and_confidences, ) -from dan.utils import read_image +from dan.utils import pairwise, read_image class DAN: """ The DAN class is used to apply a DAN model. - The class initializes useful parameters: the device. + The class initializes useful parameters: the device and the temperature scalar parameter. """ - def __init__(self, device): + def __init__(self, device, temperature=1.0): """ Constructor of the DAN class. :param device: The device to use. """ super(DAN, self).__init__() self.device = device + self.temperature = temperature def load(self, model_path, params_path, charset_path, mode="eval"): """ @@ -158,7 +159,13 @@ class DAN: ).permute(2, 0, 1) for i in range(0, self.max_chars): - output, pred, hidden_predict, cache, weights = self.decoder( + ( + output, + pred, + hidden_predict, + cache, + weights, + ) = self.decoder( features, enhanced_features, predicted_tokens, @@ -170,6 +177,8 @@ class DAN: cache=cache, num_pred=1, ) + + pred = pred / self.temperature whole_output.append(output) attention_maps.append(weights) confidence_scores.append( @@ -256,6 +265,8 @@ def run( attention_map_scale, word_separators, line_separators, + temperature, + image_max_width, predict_objects, threshold_method, threshold_value, @@ -274,6 +285,7 @@ def run( :param attention_map_scale: Scaling factor for the attention map. :param word_separators: List of word separators. :param line_separators: List of line separators. + :param image_max_width: Resize image :param predict_objects: Whether to extract objects. :param threshold_method: Thresholding method. Should be in ["otsu", "simple"]. :param threshold_value: Thresholding value to use for the "simple" thresholding method. @@ -284,11 +296,17 @@ def run( # Load model device = "cuda" if torch.cuda.is_available() else "cpu" - dan_model = DAN(device) + dan_model = DAN(device, temperature) dan_model.load(model, parameters, charset, mode="eval") # Load image and pre-process it - im = read_image(image, scale=scale) + if image_max_width: + _, w, _ = read_image(image, scale=1).shape + ratio = image_max_width / w + im = read_image(image, ratio) + else: + im = read_image(image, scale=scale) + logger.info("Image loaded.") im_p = dan_model.preprocess(im) logger.debug("Image pre-processed.") @@ -326,8 +344,20 @@ def run( # Return mean confidence score if confidence_score: result["confidences"] = {} - char_confidences = prediction["confidences"][0] + text = result["text"] + # retrieve the index of the token ner + index = [pos for pos, char in enumerate(text) if char in ["â“", "â“Ÿ", "â““", "â“¡"]] + + # calculates scores by token + + result["confidences"]["by ner token"] = [ + { + "text": f"{text[current: next_token-1]}", + "confidence_ner": f"{np.around(np.mean(char_confidences[current : next_token-1]), 2)}", + } + for current, next_token in pairwise(index + [0]) + ] result["confidences"]["total"] = np.around(np.mean(char_confidences), 2) for level in confidence_score_levels: diff --git a/dan/utils.py b/dan/utils.py index 93243fcf..7325b2b3 100644 --- a/dan/utils.py +++ b/dan/utils.py @@ -1,4 +1,6 @@ # -*- coding: utf-8 -*- +from itertools import tee + import cv2 import numpy as np import torch @@ -152,3 +154,13 @@ def round_floats(float_list, decimals=2): Round list of floats with fixed decimals """ return [np.around(num, decimals) for num in float_list] + + +def pairwise(iterable): + """ + Not necessary when using 3.10. See https://docs.python.org/3/library/itertools.html#itertools.pairwise. + """ + # pairwise('ABCDEFG') --> AB BC CD DE EF FG + a, b = tee(iterable) + next(b, None) + return zip(a, b) -- GitLab