# -*- coding: utf-8 -*-
import os
import pickle
import re
import cv2
import numpy as np
import torch
import yaml
from dan import logger
from dan.datasets.extract.utils import save_json
from dan.decoder import GlobalHTADecoder
from dan.models import FCN_Encoder
from dan.ocr.utils import LM_ind_to_str
from dan.predict.attention import plot_attention
from dan.utils import read_image, round_floats
class DAN:
"""
    The DAN class is used to apply a trained DAN model to images.
    The constructor initializes the only parameter needed for inference: the device.
"""
def __init__(self, device):
"""
Constructor of the DAN class.
:param device: The device to use.
"""
        super().__init__()
self.device = device
def load(self, model_path, params_path, charset_path, mode="eval"):
"""
Load a trained model.
:param model_path: Path to the model.
:param params_path: Path to the parameters.
:param charset_path: Path to the charset.
        :param mode: The mode in which to load the model ("train" or "eval").
"""
with open(params_path, "r") as f:
parameters = yaml.safe_load(f)["parameters"]
parameters["decoder"]["device"] = self.device
with open(charset_path, "rb") as f:
self.charset = pickle.load(f)
# Restore the model weights.
checkpoint = torch.load(model_path, map_location=self.device)
encoder = FCN_Encoder(parameters["encoder"]).to(self.device)
encoder.load_state_dict(checkpoint["encoder_state_dict"], strict=True)
decoder = GlobalHTADecoder(parameters["decoder"]).to(self.device)
decoder.load_state_dict(checkpoint["decoder_state_dict"], strict=True)
logger.debug(f"Loaded model {model_path}")
if mode == "train":
encoder.train()
decoder.train()
elif mode == "eval":
encoder.eval()
decoder.eval()
else:
            raise ValueError(f"Unsupported mode: {mode}")
self.encoder = encoder
self.decoder = decoder
self.mean, self.std = parameters["mean"], parameters["std"]
self.max_chars = parameters["max_char_prediction"]
def preprocess(self, input_image):
"""
        Preprocess an input image: convert grayscale images to RGB and
        normalize pixel values using the model's mean and standard deviation.
        :param input_image: The input image to preprocess, as an RGB np.ndarray.
"""
assert isinstance(
input_image, np.ndarray
), "Input image must be an np.array in RGB"
input_image = np.asarray(input_image)
if len(input_image.shape) < 3:
input_image = cv2.cvtColor(input_image, cv2.COLOR_GRAY2RGB)
input_image = (input_image - self.mean) / self.std
return input_image
def predict(
self,
input_tensor,
input_sizes,
confidences=False,
attentions=False,
start_token=None,
):
"""
        Run prediction on a batch of images.
        :param input_tensor: A batch of images to predict.
        :param input_sizes: The original image sizes.
        :param confidences: Whether to return character probabilities.
        :param attentions: Whether to return attention weights for each character.
        :param start_token: Optional character to use as the start token.
"""
input_tensor = input_tensor.to(self.device)
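        # Special tokens are mapped to indices outside the charset: the end
        # token is len(charset) and the default start token is len(charset) + 1.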
start_token = (
self.charset.index(start_token) if start_token else len(self.charset) + 1
)
end_token = len(self.charset)
# Run the prediction.
with torch.no_grad():
b = input_tensor.size(0)
reached_end = torch.zeros((b,), dtype=torch.bool, device=self.device)
prediction_len = torch.zeros((b,), dtype=torch.int, device=self.device)
predicted_tokens = (
torch.ones((b, 1), dtype=torch.long, device=self.device) * start_token
)
predicted_tokens_len = torch.ones((b,), dtype=torch.int, device=self.device)
            whole_output = []
            confidence_scores = []
            attention_maps = []
cache = None
hidden_predict = None
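            # Encode the image into 2D feature maps with the fully
            # convolutional encoder, then add positional encodings and flatten
            # the maps into sequences for the decoder.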
features = self.encoder(input_tensor.float())
features_size = features.size()
coverage_vector = torch.zeros(
(features.size(0), 1, features.size(2), features.size(3)),
device=self.device,
)
pos_features = self.decoder.features_updater.get_pos_features(features)
features = torch.flatten(pos_features, start_dim=2, end_dim=3).permute(
2, 0, 1
)
enhanced_features = pos_features
enhanced_features = torch.flatten(
enhanced_features, start_dim=2, end_dim=3
).permute(2, 0, 1)
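            # Autoregressive decoding: at each step, feed the tokens predicted
            # so far and ask the decoder for a single new token (num_pred=1),
            # until every sequence in the batch has emitted the end token.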
for i in range(0, self.max_chars):
output, pred, hidden_predict, cache, weights = self.decoder(
features,
enhanced_features,
predicted_tokens,
input_sizes,
predicted_tokens_len,
features_size,
start=0,
hidden_predict=hidden_predict,
cache=cache,
num_pred=1,
)
whole_output.append(output)
attention_maps.append(weights)
confidence_scores.append(
torch.max(torch.softmax(pred[:, :], dim=1), dim=1).values
)
coverage_vector = torch.clamp(coverage_vector + weights, 0, 1)
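                # Greedily pick the most likely token for the new position and
                # append it to the running predictions.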
predicted_tokens = torch.cat(
[
predicted_tokens,
torch.argmax(pred[:, :, -1], dim=1, keepdim=True),
],
dim=1,
)
reached_end = torch.logical_or(
reached_end, torch.eq(predicted_tokens[:, -1], end_token)
)
predicted_tokens_len += 1
                prediction_len[~reached_end] = i + 1
if torch.all(reached_end):
break
# Concatenate tensors for each token
confidence_scores = (
torch.cat(confidence_scores, dim=1).cpu().detach().numpy()
)
attention_maps = torch.cat(attention_maps, dim=1).cpu().detach().numpy()
            # Remove the start (BOT) token; sequences that never emitted the
            # end (EOT) token are truncated at the maximum prediction length.
            predicted_tokens = predicted_tokens[:, 1:]
            prediction_len[~reached_end] = self.max_chars - 1
predicted_tokens = [
predicted_tokens[i, : prediction_len[i]] for i in range(b)
]
confidence_scores = [
confidence_scores[i, : prediction_len[i]].tolist() for i in range(b)
]
# Transform tokens to characters
predicted_text = [
LM_ind_to_str(self.charset, t, oov_symbol="") for t in predicted_tokens
]
logger.info("Images processed")
out = {"text": predicted_text}
if confidences:
out["confidences"] = confidence_scores
if attentions:
out["attentions"] = attention_maps
return out
def parse_delimiters(delimiters):
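    # e.g. parse_delimiters(["\n", " "]) -> re.compile("\n| "),
    # i.e. a pattern matching either a newline or a space.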
return re.compile(r"|".join(delimiters))
def compute_prob_by_separator(characters, probabilities, separator):
"""
Split text and confidences using separators and return a list of average confidence scores.
:param characters: list of characters.
:param probabilities: list of probabilities.
    :param separator: Compiled regex of separators. Use parse_delimiters(["\n", " "]) for word confidences and parse_delimiters(["\n"]) for line confidences.
    Returns a list of average confidence scores.
"""
    # Match anything except separators; record the start and end index of each piece
pattern = re.compile(f"[^{separator.pattern}]+")
matches = [(m.start(), m.end()) for m in re.finditer(pattern, characters)]
# Iterate over text pieces and compute mean confidence
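    # For example (hypothetical values): characters "ab cd" with probabilities
    # [0.9, 0.8, 0.5, 0.7, 0.6] and separator parse_delimiters([" "]) match
    # spans (0, 2) and (3, 5), giving [0.85, 0.65].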
return [np.mean(probabilities[start:end]) for (start, end) in matches]
def run(
image,
model,
parameters,
charset,
output,
scale,
confidence_score,
confidence_score_levels,
attention_map,
attention_map_level,
attention_map_scale,
word_separators,
line_separators,
):
# Create output directory if necessary
    os.makedirs(output, exist_ok=True)
# Load model
device = "cuda" if torch.cuda.is_available() else "cpu"
dan_model = DAN(device)
dan_model.load(model, parameters, charset, mode="eval")
# Load image and pre-process it
im = read_image(image, scale=scale)
logger.info("Image loaded.")
im_p = dan_model.preprocess(im)
logger.debug("Image pre-processed.")
# Convert to tensor of size (batch_size, channel, height, width) with batch_size=1
input_tensor = torch.tensor(im_p).permute(2, 0, 1).unsqueeze(0)
input_tensor = input_tensor.to(device)
input_sizes = [im.shape[:2]]
# Predict
prediction = dan_model.predict(
input_tensor,
input_sizes,
confidences=confidence_score,
attentions=attention_map,
)
text = prediction["text"][0]
result = {"text": text}
# Parse delimiters to regex
word_separators = parse_delimiters(word_separators)
line_separators = parse_delimiters(line_separators)
# Average character-based confidence scores
if confidence_score:
char_confidences = prediction["confidences"][0]
result["confidences"] = {"total": np.around(np.mean(char_confidences), 2)}
if "word" in confidence_score_levels:
word_probs = compute_prob_by_separator(
text, char_confidences, word_separators
)
result["confidences"].update({"word": round_floats(word_probs)})
if "line" in confidence_score_levels:
line_probs = compute_prob_by_separator(
text, char_confidences, line_separators
)
result["confidences"].update({"line": round_floats(line_probs)})
if "char" in confidence_score_levels:
result["confidences"].update({"char": round_floats(char_confidences)})
# Save gif with attention map
if attention_map:
gif_filename = f"{output}/{image.stem}_{attention_map_level}.gif"
logger.info(f"Creating attention GIF in {gif_filename}")
plot_attention(
image=im,
text=prediction["text"][0],
weights=prediction["attentions"][0],
level=attention_map_level,
scale=attention_map_scale,
word_separators=word_separators,
line_separators=line_separators,
outname=gif_filename,
)
result["attention_gif"] = gif_filename
json_filename = f"{output}/{image.stem}.json"
logger.info(f"Saving JSON prediction in {json_filename}")
save_json(json_filename, result)
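# A minimal usage sketch, assuming hypothetical local paths: the model
# weights, parameters and charset files would come from a trained DAN
# experiment, and "page.jpg" is a sample page image.
if __name__ == "__main__":
    from pathlib import Path

    run(
        image=Path("page.jpg"),
        model="model.pt",
        parameters="parameters.yml",
        charset="charset.pkl",
        output="predictions",
        scale=1.0,
        confidence_score=True,
        confidence_score_levels=["word", "line"],
        attention_map=False,
        attention_map_level="line",
        attention_map_scale=0.5,
        word_separators=["\n", " "],
        line_separators=["\n"],
    )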