Skip to content
Snippets Groups Projects
Commit 3626aa70 authored by Marie Generali's avatar Marie Generali :worried: Committed by Yoann Schneider
Browse files

Implement temperature scaling on dan

parent fca6f2d1
No related branches found
No related tags found
1 merge request!146Implement temperature scaling on dan
......@@ -55,6 +55,20 @@ def add_predict_parser(subcommands) -> None:
required=False,
help="Image scaling factor before feeding it to DAN",
)
parser.add_argument(
"--image-max-width",
type=int,
default=1800,
required=False,
help="Image resizing before feeding it to DAN",
)
parser.add_argument(
"--temperature",
type=float,
default=1.0,
help="Temperature scaling scalar parameter",
required=True,
)
parser.add_argument(
"--confidence-score",
action="store_true",
......
......@@ -20,22 +20,23 @@ from dan.predict.attention import (
plot_attention,
split_text_and_confidences,
)
from dan.utils import read_image
from dan.utils import pairwise, read_image
class DAN:
"""
The DAN class is used to apply a DAN model.
The class initializes useful parameters: the device.
The class initializes useful parameters: the device and the temperature scalar parameter.
"""
def __init__(self, device):
def __init__(self, device, temperature=1.0):
"""
Constructor of the DAN class.
:param device: The device to use.
"""
super(DAN, self).__init__()
self.device = device
self.temperature = temperature
def load(self, model_path, params_path, charset_path, mode="eval"):
"""
......@@ -158,7 +159,13 @@ class DAN:
).permute(2, 0, 1)
for i in range(0, self.max_chars):
output, pred, hidden_predict, cache, weights = self.decoder(
(
output,
pred,
hidden_predict,
cache,
weights,
) = self.decoder(
features,
enhanced_features,
predicted_tokens,
......@@ -170,6 +177,8 @@ class DAN:
cache=cache,
num_pred=1,
)
pred = pred / self.temperature
whole_output.append(output)
attention_maps.append(weights)
confidence_scores.append(
......@@ -256,6 +265,8 @@ def run(
attention_map_scale,
word_separators,
line_separators,
temperature,
image_max_width,
predict_objects,
threshold_method,
threshold_value,
......@@ -274,6 +285,7 @@ def run(
:param attention_map_scale: Scaling factor for the attention map.
:param word_separators: List of word separators.
:param line_separators: List of line separators.
:param image_max_width: Resize image
:param predict_objects: Whether to extract objects.
:param threshold_method: Thresholding method. Should be in ["otsu", "simple"].
:param threshold_value: Thresholding value to use for the "simple" thresholding method.
......@@ -284,11 +296,17 @@ def run(
# Load model
device = "cuda" if torch.cuda.is_available() else "cpu"
dan_model = DAN(device)
dan_model = DAN(device, temperature)
dan_model.load(model, parameters, charset, mode="eval")
# Load image and pre-process it
im = read_image(image, scale=scale)
if image_max_width:
_, w, _ = read_image(image, scale=1).shape
ratio = image_max_width / w
im = read_image(image, ratio)
else:
im = read_image(image, scale=scale)
logger.info("Image loaded.")
im_p = dan_model.preprocess(im)
logger.debug("Image pre-processed.")
......@@ -326,8 +344,20 @@ def run(
# Return mean confidence score
if confidence_score:
result["confidences"] = {}
char_confidences = prediction["confidences"][0]
text = result["text"]
# retrieve the index of the token ner
index = [pos for pos, char in enumerate(text) if char in ["", "", "", ""]]
# calculates scores by token
result["confidences"]["by ner token"] = [
{
"text": f"{text[current: next_token-1]}",
"confidence_ner": f"{np.around(np.mean(char_confidences[current : next_token-1]), 2)}",
}
for current, next_token in pairwise(index + [0])
]
result["confidences"]["total"] = np.around(np.mean(char_confidences), 2)
for level in confidence_score_levels:
......
# -*- coding: utf-8 -*-
from itertools import tee
import cv2
import numpy as np
import torch
......@@ -152,3 +154,13 @@ def round_floats(float_list, decimals=2):
Round list of floats with fixed decimals
"""
return [np.around(num, decimals) for num in float_list]
def pairwise(iterable):
"""
Not necessary when using 3.10. See https://docs.python.org/3/library/itertools.html#itertools.pairwise.
"""
# pairwise('ABCDEFG') --> AB BC CD DE EF FG
a, b = tee(iterable)
next(b, None)
return zip(a, b)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment