Skip to content
Snippets Groups Projects
Commit e0df5532 authored by Marie Generali's avatar Marie Generali :worried:
Browse files

same code but after pre-commit

parent 1fb0d691
No related branches found
No related tags found
No related merge requests found
......@@ -382,7 +382,7 @@ class GlobalHTADecoder(Module):
cache=None,
num_pred=None,
keep_all_weights=False,
temperature = 1,
temperature=1,
token_line=None,
token_pg=None,
):
......@@ -454,9 +454,9 @@ class GlobalHTADecoder(Module):
weights = torch.sum(weights, dim=1, keepdim=True).reshape(
-1, 1, features_size[2], features_size[3]
)
temperature = self.temperature
return output, preds, hidden_predict, cache, weights, temperature
def generate_enc_mask(self, batch_reduced_size, total_size, device):
......
......@@ -3,7 +3,6 @@
import os
import pickle
from pathlib import Path
from typing import DefaultDict
import cv2
import numpy as np
......@@ -159,7 +158,14 @@ class DAN:
).permute(2, 0, 1)
for i in range(0, self.max_chars):
output, pred, hidden_predict, cache, weights, temperature = self.decoder(
(
output,
pred,
hidden_predict,
cache,
weights,
temperature,
) = self.decoder(
features,
enhanced_features,
predicted_tokens,
......@@ -170,9 +176,9 @@ class DAN:
hidden_predict=hidden_predict,
cache=cache,
num_pred=1,
temperature = 1,
temperature=1,
)
pred = pred / temperature
pred = pred / temperature
whole_output.append(output)
attention_maps.append(weights)
confidence_scores.append(
......@@ -199,7 +205,6 @@ class DAN:
confidence_scores = (
torch.cat(confidence_scores, dim=1).cpu().detach().numpy()
)
token_confidence_scores = confidence_scores
attention_maps = torch.cat(attention_maps, dim=1).cpu().detach().numpy()
# Remove bot and eot tokens
......@@ -293,7 +298,7 @@ def run(
# Load image and pre-process it
im = read_image(image, scale=scale)
h, w, c = read_image(image, scale= 1).shape
h, w, c = read_image(image, scale=1).shape
ratio = 1800 / w
im = read_image(image, ratio)
logger.info("Image loaded.")
......@@ -334,22 +339,31 @@ def run(
if confidence_score:
result["confidences"] = {}
char_confidences = prediction["confidences"][0]
text = result['text']
#retrieve the index of the token ner
index = [pos for pos, char in enumerate(text) if char in ["", "", "" , ""]]
text = result["text"]
# retrieve the index of the token ner
index = [pos for pos, char in enumerate(text) if char in ["", "", "", ""]]
# calculates scores by token
score_by_token = []
for rang, position in enumerate(index[:-1]):
score_by_token.append({'text':f'{text[position: index[rang+1]-1]}', 'confidence_ner' : f'{np.around(np.mean(char_confidences[position : index[rang+1]-1]), 2)}'})
score_by_token.append({'text':f'{text[index[-2]: index[-1]]}', 'confidence_ner' : f'{np.around(np.mean(char_confidences[index[-2] : index[-1]]), 2)}'})
score_by_token.append(
{
"text": f"{text[position: index[rang+1]-1]}",
"confidence_ner": f"{np.around(np.mean(char_confidences[position : index[rang+1]-1]), 2)}",
}
)
result["confidences"]["total"] = np.around(np.mean(char_confidences), 2)
result['confidences']['by ner token']=[]
for entity in score_by_token:
result["confidences"]['by ner token'].append(entity)
score_by_token.append(
{
"text": f"{text[index[-2]: index[-1]]}",
"confidence_ner": f"{np.around(np.mean(char_confidences[index[-2] : index[-1]]), 2)}",
}
)
result["confidences"]["total"] = np.around(np.mean(char_confidences), 2)
result["confidences"]["by ner token"] = []
for entity in score_by_token:
result["confidences"]["by ner token"].append(entity)
for level in confidence_score_levels:
result["confidences"][level] = []
......
......@@ -202,7 +202,6 @@ def read_image(filename, scale=1.0):
return image
def round_floats(float_list, decimals=2):
"""
Round list of floats with fixed decimals
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment