Skip to content
Snippets Groups Projects
Commit 3626aa70 authored by Marie Generali's avatar Marie Generali :worried: Committed by Yoann Schneider
Browse files

Implement temperature scaling on dan

parent fca6f2d1
No related branches found
No related tags found
1 merge request!146Implement temperature scaling on dan
...@@ -55,6 +55,20 @@ def add_predict_parser(subcommands) -> None: ...@@ -55,6 +55,20 @@ def add_predict_parser(subcommands) -> None:
required=False, required=False,
help="Image scaling factor before feeding it to DAN", help="Image scaling factor before feeding it to DAN",
) )
parser.add_argument(
"--image-max-width",
type=int,
default=1800,
required=False,
help="Image resizing before feeding it to DAN",
)
parser.add_argument(
"--temperature",
type=float,
default=1.0,
help="Temperature scaling scalar parameter",
required=True,
)
parser.add_argument( parser.add_argument(
"--confidence-score", "--confidence-score",
action="store_true", action="store_true",
......
...@@ -20,22 +20,23 @@ from dan.predict.attention import ( ...@@ -20,22 +20,23 @@ from dan.predict.attention import (
plot_attention, plot_attention,
split_text_and_confidences, split_text_and_confidences,
) )
from dan.utils import read_image from dan.utils import pairwise, read_image
class DAN: class DAN:
""" """
The DAN class is used to apply a DAN model. The DAN class is used to apply a DAN model.
The class initializes useful parameters: the device. The class initializes useful parameters: the device and the temperature scalar parameter.
""" """
def __init__(self, device): def __init__(self, device, temperature=1.0):
""" """
Constructor of the DAN class. Constructor of the DAN class.
:param device: The device to use. :param device: The device to use.
""" """
super(DAN, self).__init__() super(DAN, self).__init__()
self.device = device self.device = device
self.temperature = temperature
def load(self, model_path, params_path, charset_path, mode="eval"): def load(self, model_path, params_path, charset_path, mode="eval"):
""" """
...@@ -158,7 +159,13 @@ class DAN: ...@@ -158,7 +159,13 @@ class DAN:
).permute(2, 0, 1) ).permute(2, 0, 1)
for i in range(0, self.max_chars): for i in range(0, self.max_chars):
output, pred, hidden_predict, cache, weights = self.decoder( (
output,
pred,
hidden_predict,
cache,
weights,
) = self.decoder(
features, features,
enhanced_features, enhanced_features,
predicted_tokens, predicted_tokens,
...@@ -170,6 +177,8 @@ class DAN: ...@@ -170,6 +177,8 @@ class DAN:
cache=cache, cache=cache,
num_pred=1, num_pred=1,
) )
pred = pred / self.temperature
whole_output.append(output) whole_output.append(output)
attention_maps.append(weights) attention_maps.append(weights)
confidence_scores.append( confidence_scores.append(
...@@ -256,6 +265,8 @@ def run( ...@@ -256,6 +265,8 @@ def run(
attention_map_scale, attention_map_scale,
word_separators, word_separators,
line_separators, line_separators,
temperature,
image_max_width,
predict_objects, predict_objects,
threshold_method, threshold_method,
threshold_value, threshold_value,
...@@ -274,6 +285,7 @@ def run( ...@@ -274,6 +285,7 @@ def run(
:param attention_map_scale: Scaling factor for the attention map. :param attention_map_scale: Scaling factor for the attention map.
:param word_separators: List of word separators. :param word_separators: List of word separators.
:param line_separators: List of line separators. :param line_separators: List of line separators.
:param image_max_width: Resize image
:param predict_objects: Whether to extract objects. :param predict_objects: Whether to extract objects.
:param threshold_method: Thresholding method. Should be in ["otsu", "simple"]. :param threshold_method: Thresholding method. Should be in ["otsu", "simple"].
:param threshold_value: Thresholding value to use for the "simple" thresholding method. :param threshold_value: Thresholding value to use for the "simple" thresholding method.
...@@ -284,11 +296,17 @@ def run( ...@@ -284,11 +296,17 @@ def run(
# Load model # Load model
device = "cuda" if torch.cuda.is_available() else "cpu" device = "cuda" if torch.cuda.is_available() else "cpu"
dan_model = DAN(device) dan_model = DAN(device, temperature)
dan_model.load(model, parameters, charset, mode="eval") dan_model.load(model, parameters, charset, mode="eval")
# Load image and pre-process it # Load image and pre-process it
im = read_image(image, scale=scale) if image_max_width:
_, w, _ = read_image(image, scale=1).shape
ratio = image_max_width / w
im = read_image(image, ratio)
else:
im = read_image(image, scale=scale)
logger.info("Image loaded.") logger.info("Image loaded.")
im_p = dan_model.preprocess(im) im_p = dan_model.preprocess(im)
logger.debug("Image pre-processed.") logger.debug("Image pre-processed.")
...@@ -326,8 +344,20 @@ def run( ...@@ -326,8 +344,20 @@ def run(
# Return mean confidence score # Return mean confidence score
if confidence_score: if confidence_score:
result["confidences"] = {} result["confidences"] = {}
char_confidences = prediction["confidences"][0] char_confidences = prediction["confidences"][0]
text = result["text"]
# retrieve the index of the token ner
index = [pos for pos, char in enumerate(text) if char in ["", "", "", ""]]
# calculates scores by token
result["confidences"]["by ner token"] = [
{
"text": f"{text[current: next_token-1]}",
"confidence_ner": f"{np.around(np.mean(char_confidences[current : next_token-1]), 2)}",
}
for current, next_token in pairwise(index + [0])
]
result["confidences"]["total"] = np.around(np.mean(char_confidences), 2) result["confidences"]["total"] = np.around(np.mean(char_confidences), 2)
for level in confidence_score_levels: for level in confidence_score_levels:
......
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from itertools import tee
import cv2 import cv2
import numpy as np import numpy as np
import torch import torch
...@@ -152,3 +154,13 @@ def round_floats(float_list, decimals=2): ...@@ -152,3 +154,13 @@ def round_floats(float_list, decimals=2):
Round list of floats with fixed decimals Round list of floats with fixed decimals
""" """
return [np.around(num, decimals) for num in float_list] return [np.around(num, decimals) for num in float_list]
def pairwise(iterable):
"""
Not necessary when using 3.10. See https://docs.python.org/3/library/itertools.html#itertools.pairwise.
"""
# pairwise('ABCDEFG') --> AB BC CD DE EF FG
a, b = tee(iterable)
next(b, None)
return zip(a, b)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment