Skip to content
Snippets Groups Projects
Commit 1fb0d691 authored by Marie Generali's avatar Marie Generali :worried:
Browse files

implement temperature scaling on dan

parent 378a86c0
No related branches found
No related tags found
No related merge requests found
......@@ -367,6 +367,7 @@ class GlobalHTADecoder(Module):
vocab_size = params["vocab_size"] + 1
self.end_conv = Conv1d(self.enc_dim, vocab_size, kernel_size=1)
self.temperature = params["temperature"]
def forward(
self,
......@@ -381,6 +382,7 @@ class GlobalHTADecoder(Module):
cache=None,
num_pred=None,
keep_all_weights=False,
temperature = 1,
token_line=None,
token_pg=None,
):
......@@ -452,7 +454,10 @@ class GlobalHTADecoder(Module):
weights = torch.sum(weights, dim=1, keepdim=True).reshape(
-1, 1, features_size[2], features_size[3]
)
return output, preds, hidden_predict, cache, weights
temperature = self.temperature
return output, preds, hidden_predict, cache, weights, temperature
def generate_enc_mask(self, batch_reduced_size, total_size, device):
"""
......
......@@ -3,6 +3,7 @@
import os
import pickle
from pathlib import Path
from typing import DefaultDict
import cv2
import numpy as np
......@@ -158,7 +159,7 @@ class DAN:
).permute(2, 0, 1)
for i in range(0, self.max_chars):
output, pred, hidden_predict, cache, weights = self.decoder(
output, pred, hidden_predict, cache, weights, temperature = self.decoder(
features,
enhanced_features,
predicted_tokens,
......@@ -169,7 +170,9 @@ class DAN:
hidden_predict=hidden_predict,
cache=cache,
num_pred=1,
temperature = 1,
)
pred = pred / temperature
whole_output.append(output)
attention_maps.append(weights)
confidence_scores.append(
......@@ -196,6 +199,7 @@ class DAN:
confidence_scores = (
torch.cat(confidence_scores, dim=1).cpu().detach().numpy()
)
token_confidence_scores = confidence_scores
attention_maps = torch.cat(attention_maps, dim=1).cpu().detach().numpy()
# Remove bot and eot tokens
......@@ -289,6 +293,9 @@ def run(
# Load image and pre-process it
im = read_image(image, scale=scale)
h, w, c = read_image(image, scale= 1).shape
ratio = 1800 / w
im = read_image(image, ratio)
logger.info("Image loaded.")
im_p = dan_model.preprocess(im)
logger.debug("Image pre-processed.")
......@@ -326,9 +333,23 @@ def run(
# Return mean confidence score
if confidence_score:
result["confidences"] = {}
char_confidences = prediction["confidences"][0]
text = result['text']
# Retrieve the positions of the NER tokens in the text
index = [pos for pos, char in enumerate(text) if char in ["", "", "" , ""]]
# Compute the mean confidence score for each NER token span
score_by_token = []
for rang, position in enumerate(index[:-1]):
score_by_token.append({'text':f'{text[position: index[rang+1]-1]}', 'confidence_ner' : f'{np.around(np.mean(char_confidences[position : index[rang+1]-1]), 2)}'})
score_by_token.append({'text':f'{text[index[-2]: index[-1]]}', 'confidence_ner' : f'{np.around(np.mean(char_confidences[index[-2] : index[-1]]), 2)}'})
result["confidences"]["total"] = np.around(np.mean(char_confidences), 2)
result['confidences']['by ner token']=[]
for entity in score_by_token:
result["confidences"]['by ner token'].append(entity)
for level in confidence_score_levels:
result["confidences"][level] = []
......
......@@ -198,9 +198,11 @@ def read_image(filename, scale=1.0):
width = int(image.shape[1] * scale)
height = int(image.shape[0] * scale)
image = cv2.resize(image, (width, height), interpolation=cv2.INTER_AREA)
return image
def round_floats(float_list, decimals=2):
"""
Round list of floats with fixed decimals
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment