Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • atr/dan
1 result
Show changes
Commits on Source (9)
......@@ -452,6 +452,7 @@ class GlobalHTADecoder(Module):
weights = torch.sum(weights, dim=1, keepdim=True).reshape(
-1, 1, features_size[2], features_size[3]
)
return output, preds, hidden_predict, cache, weights
def generate_enc_mask(self, batch_reduced_size, total_size, device):
......
......@@ -1036,7 +1036,9 @@ class Manager(OCRManager):
enhanced_features = torch.flatten(
enhanced_features, start_dim=2, end_dim=3
).permute(2, 0, 1)
output, pred, hidden_predict, cache, weights = self.models["decoder"](
output, pred, hidden_predict, cache, weights, temperature = self.models[
"decoder"
](
features,
enhanced_features,
simulated_y_pred[:, :-1],
......@@ -1133,7 +1135,9 @@ class Manager(OCRManager):
).permute(2, 0, 1)
for i in range(0, max_chars):
output, pred, hidden_predict, cache, weights = self.models["decoder"](
output, pred, hidden_predict, cache, weights, temperature = self.models[
"decoder"
](
features,
enhanced_features,
predicted_tokens,
......
......@@ -167,6 +167,7 @@ def get_config():
"use_2d_pe": True, # use 2D positional embedding
"use_1d_pe": True, # use 1D positional embedding
"use_lstm": False,
"temperature": 1.0, # temperature scaling scalar parameter
"attention_win": 100, # length of attention window
# Curriculum dropout
"dropout_scheduler": {
......
......@@ -55,6 +55,13 @@ def add_predict_parser(subcommands) -> None:
required=False,
help="Image scaling factor before feeding it to DAN",
)
parser.add_argument(
"--max_width",
type=int,
default=1800,
required=False,
help="Image resizing before feeding it to DAN",
)
parser.add_argument(
"--confidence-score",
action="store_true",
......
......@@ -20,22 +20,23 @@ from dan.predict.attention import (
plot_attention,
split_text_and_confidences,
)
from dan.utils import read_image
from dan.utils import pairwise, read_image
class DAN:
"""
The DAN class is used to apply a DAN model.
The class initializes useful parameters: the device.
The class initializes useful parameters: the device and the temperature scalara parameter.
"""
def __init__(self, device):
def __init__(self, device, temperature=1.0):
"""
Constructor of the DAN class.
:param device: The device to use.
"""
super(DAN, self).__init__()
self.device = device
self.temperature = temperature
def load(self, model_path, params_path, charset_path, mode="eval"):
"""
......@@ -104,6 +105,7 @@ class DAN:
start_token=None,
threshold_method="otsu",
threshold_value=0,
temperature=1.0,
):
"""
Run prediction on an input image.
......@@ -158,7 +160,13 @@ class DAN:
).permute(2, 0, 1)
for i in range(0, self.max_chars):
output, pred, hidden_predict, cache, weights = self.decoder(
(
output,
pred,
hidden_predict,
cache,
weights,
) = self.decoder(
features,
enhanced_features,
predicted_tokens,
......@@ -170,6 +178,8 @@ class DAN:
cache=cache,
num_pred=1,
)
pred = pred / temperature
whole_output.append(output)
attention_maps.append(weights)
confidence_scores.append(
......@@ -256,6 +266,8 @@ def run(
attention_map_scale,
word_separators,
line_separators,
temperature,
image_max_width,
predict_objects,
threshold_method,
threshold_value,
......@@ -274,6 +286,7 @@ def run(
:param attention_map_scale: Scaling factor for the attention map.
:param word_separators: List of word separators.
:param line_separators: List of line separators.
:param image_max_width: Resize image
:param predict_objects: Whether to extract objects.
:param threshold_method: Thresholding method. Should be in ["otsu", "simple"].
:param threshold_value: Thresholding value to use for the "simple" thresholding method.
......@@ -284,11 +297,17 @@ def run(
# Load model
device = "cuda" if torch.cuda.is_available() else "cpu"
dan_model = DAN(device)
dan_model = DAN(device, temperature)
dan_model.load(model, parameters, charset, mode="eval")
# Load image and pre-process it
im = read_image(image, scale=scale)
# Load image and pre-process it
if image_max_width:
h, w, c = read_image(image, scale=1).shape
ratio = image_max_width / w
im = read_image(image, ratio)
logger.info("Image loaded.")
im_p = dan_model.preprocess(im)
logger.debug("Image pre-processed.")
......@@ -326,9 +345,24 @@ def run(
# Return mean confidence score
if confidence_score:
result["confidences"] = {}
char_confidences = prediction["confidences"][0]
text = result["text"]
# retrieve the index of the token ner
index = [pos for pos, char in enumerate(text) if char in ["", "", "", ""]]
# calculates scores by token
score_by_token = [
{
"text": f"{text[current: next_token-1]}",
"confidence_ner": f"{np.around(np.mean(char_confidences[current : next_token-1]), 2)}",
}
for current, next_token in pairwise(index)
]
result["confidences"]["total"] = np.around(np.mean(char_confidences), 2)
result["confidences"]["by ner token"] = []
for entity in score_by_token:
result["confidences"]["by ner token"].append(entity)
for level in confidence_score_levels:
result["confidences"][level] = []
......
# -*- coding: utf-8 -*-
from itertools import tee
import cv2
import numpy as np
import torch
......@@ -206,3 +208,13 @@ def round_floats(float_list, decimals=2):
Round list of floats with fixed decimals
"""
return [np.around(num, decimals) for num in float_list]
def pairwise(iterable):
"""
Not necessary when using 3.10. See https://docs.python.org/3/library/itertools.html#itertools.pairwise.
"""
# pairwise('ABCDEFG') --> AB BC CD DE EF FG
a, b = tee(iterable)
next(b, None)
return zip(a, b)
......@@ -107,6 +107,7 @@ def training_config():
"dec_pred_dropout": 0.1, # dropout rate before decision layer
"dec_att_dropout": 0.1, # dropout rate in multi head attention
"dec_dim_feedforward": 256, # number of dimension for feedforward layer in transformer decoder layers
"temperature": 1.0, # temperature scaling scalar parameter
"use_2d_pe": True, # use 2D positional embedding
"use_1d_pe": True, # use 1D positional embedding
"use_lstm": False,
......
......@@ -22,3 +22,4 @@ parameters:
dec_num_heads: 4
dec_att_dropout: 0.1
dec_res_dropout: 0.1
temperature: 1.0