# -*- coding: utf-8 -*-
import json
import logging
import pickle
import re
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import numpy as np
import torch
import yaml
from dan.ocr.decoder import CTCLanguageDecoder, GlobalHTADecoder
from dan.ocr.encoder import FCN_Encoder
from dan.ocr.predict.attention import (
Level,
get_predicted_polygons_with_confidence,
parse_delimiters,
plot_attention,
split_text_and_confidences,
)
from dan.ocr.transforms import get_preprocessing_transforms
from dan.utils import (
EntityType,
ind_to_token,
list_to_batches,
pad_images,
read_image,
)
logger = logging.getLogger(__name__)
class DAN:
    """
    The DAN class runs inference with a trained DAN model.
    It stores the inference device and the temperature scalar used to smooth
    the logits before confidence scores are computed.
    """
    def __init__(self, device: str, temperature: float = 1.0) -> None:
        """
        Constructor of the DAN class.
        :param device: The device to use.
        :param temperature: Temperature scalar used to smooth the logits.
        """
        super().__init__()
        self.device = device
        self.temperature = temperature
def load(
self,
model_path: Path,
params_path: Path,
charset_path: Path,
mode: str = "eval",
use_language_model: bool = False,
) -> None:
"""
Load a trained model.
:param model_path: Path to the model.
:param params_path: Path to the parameters.
:param charset_path: Path to the charset.
:param mode: The mode to load the model (train or eval).
:param use_language_model: Whether to use an explicit language model to rescore text hypotheses.
"""
parameters = yaml.safe_load(params_path.read_text())["parameters"]
parameters["decoder"]["device"] = self.device
self.charset = pickle.loads(charset_path.read_bytes())
# Restore the model weights.
checkpoint = torch.load(model_path, map_location=self.device)
encoder = FCN_Encoder(parameters["encoder"]).to(self.device)
encoder.load_state_dict(checkpoint["encoder_state_dict"], strict=True)
decoder = GlobalHTADecoder(parameters["decoder"]).to(self.device)
decoder.load_state_dict(checkpoint["decoder_state_dict"], strict=True)
logger.debug(f"Loaded model {model_path}")
if mode == "train":
encoder.train()
decoder.train()
elif mode == "eval":
encoder.eval()
decoder.eval()
        else:
            raise ValueError(f"Unsupported mode: {mode}")
self.encoder = encoder
self.decoder = decoder
self.lm_decoder = None
if use_language_model and parameters["language_model"]["weight"] > 0:
logger.info(
f"Decoding with a language model (weight={parameters['language_model']['weight']})."
)
self.lm_decoder = CTCLanguageDecoder(
language_model_path=parameters["language_model"]["model"],
lexicon_path=parameters["language_model"]["lexicon"],
tokens_path=parameters["language_model"]["tokens"],
language_model_weight=parameters["language_model"]["weight"],
)
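        # The stored statistics are in the [0, 255] range; bring them to [0, 1]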
self.mean, self.std = (
torch.tensor(parameters["mean"]) / 255,
torch.tensor(parameters["std"]) / 255,
)
self.preprocessing_transforms = get_preprocessing_transforms(
parameters.get("preprocessings", [])
)
self.max_chars = parameters["max_char_prediction"]
    def preprocess(self, path: str) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Preprocess an image.
        :param path: Path of the image to load and preprocess.
        :return: The preprocessed image and its channel-wise normalized version.
        """
image = read_image(path)
preprocessed_image = self.preprocessing_transforms(image)
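        # Normalize each channel with the dataset statistics loaded in `load()`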
normalized_image = torch.zeros(preprocessed_image.shape)
for ch in range(preprocessed_image.shape[0]):
normalized_image[ch, :, :] = (
preprocessed_image[ch, :, :] - self.mean[ch]
) / self.std[ch]
return preprocessed_image, normalized_image
def predict(
self,
input_tensor: torch.Tensor,
input_sizes: List[torch.Size],
confidences: bool = False,
attentions: bool = False,
attention_level: Level = Level.Line,
extract_objects: bool = False,
word_separators: re.Pattern = parse_delimiters(["\n", " "]),
line_separators: re.Pattern = parse_delimiters(["\n"]),
tokens: Dict[str, EntityType] = {},
        start_token: Optional[str] = None,
max_object_height: int = 50,
use_language_model: bool = False,
) -> dict:
"""
Run prediction on an input image.
:param input_tensor: A batch of images to predict.
:param input_sizes: The original images sizes.
:param confidences: Return the characters probabilities.
:param attentions: Return characters attention weights.
:param attention_level: Level of text pieces (must be in [char, word, line, ner])
:param extract_objects: Whether to extract polygons' coordinates.
:param max_object_height: Maximum height of predicted objects.
"""
input_tensor = input_tensor.to(self.device)
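        # The end token is the last charset index; when no start token is given,
        # its index defaults to the slot right after the end token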
start_token = (
self.charset.index(start_token) if start_token else len(self.charset) + 1
)
end_token = len(self.charset)
# Run the prediction.
with torch.no_grad():
batch_size = input_tensor.size(0)
reached_end = torch.zeros(
(batch_size,), dtype=torch.bool, device=self.device
)
prediction_len = torch.zeros(
(batch_size,), dtype=torch.int, device=self.device
)
predicted_tokens = (
torch.ones((batch_size, 1), dtype=torch.long, device=self.device)
* start_token
)
predicted_tokens_len = torch.ones(
(batch_size,), dtype=torch.int, device=self.device
)
            # Raw logits for every step, including the end token class, kept for the CTC language model decoder
tot_pred = torch.zeros(
(batch_size, len(self.charset) + 1, self.max_chars),
dtype=torch.float,
device=self.device,
)
            whole_output = []
            confidence_scores = []
            attention_maps = []
cache = None
hidden_predict = None
features = self.encoder(input_tensor.float())
features_size = features.size()
features = self.decoder.features_updater.get_pos_features(features)
features = torch.flatten(features, start_dim=2, end_dim=3).permute(2, 0, 1)
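            # Autoregressive decoding: predict one token per step, stopping as
            # soon as every sequence in the batch has emitted the end token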
for i in range(0, self.max_chars):
(
output,
pred,
hidden_predict,
cache,
weights,
) = self.decoder(
features,
predicted_tokens,
input_sizes,
predicted_tokens_len,
features_size,
start=0,
hidden_predict=hidden_predict,
cache=cache,
num_pred=1,
)
                # Keep the raw logits of this step for the language model decoder
tot_pred[:, :, i : i + 1] = pred
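                # Temperature scaling flattens (T > 1) or sharpens (T < 1) the
                # softmax used for the confidence scores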
pred = pred / self.temperature
whole_output.append(output)
attention_maps.append(weights)
confidence_scores.append(
torch.max(torch.softmax(pred[:, :], dim=1), dim=1).values
)
predicted_tokens = torch.cat(
[
predicted_tokens,
torch.argmax(pred[:, :, -1], dim=1, keepdim=True),
],
dim=1,
)
reached_end = torch.logical_or(
reached_end, torch.eq(predicted_tokens[:, -1], end_token)
)
predicted_tokens_len += 1
                prediction_len[~reached_end] = i + 1
if torch.all(reached_end):
break
# Concatenate tensors for each token
confidence_scores = (
torch.cat(confidence_scores, dim=1).cpu().detach().numpy()
)
attention_maps = torch.cat(attention_maps, dim=1).cpu().detach().numpy()
            # Remove the start token; the end token is dropped via prediction_len
            predicted_tokens = predicted_tokens[:, 1:]
            # Sequences that never emitted the end token keep the maximum length
            prediction_len[~reached_end] = self.max_chars - 1
predicted_tokens = [
predicted_tokens[i, : prediction_len[i]] for i in range(batch_size)
]
confidence_scores = [
confidence_scores[i, : prediction_len[i]].tolist()
for i in range(batch_size)
]
# Transform tokens to characters
predicted_text = [
ind_to_token(self.charset, t, oov_symbol="") for t in predicted_tokens
]
logger.info("Images processed")
out = {}
out["text"] = predicted_text
if use_language_model:
out["language_model"] = self.lm_decoder(tot_pred, prediction_len)
if confidences:
out["confidences"] = confidence_scores
if attentions:
out["attentions"] = attention_maps
if extract_objects:
out["objects"] = [
get_predicted_polygons_with_confidence(
predicted_text[i],
attention_maps[i],
confidence_scores[i],
attention_level,
input_sizes[i][0],
input_sizes[i][1],
max_object_height=max_object_height,
word_separators=word_separators,
line_separators=line_separators,
tokens=tokens,
)
for i in range(batch_size)
]
return out
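
# A minimal usage sketch of the DAN class (hypothetical paths: "model.pt",
# "parameters.yml", "charset.pkl" and "page.jpg" are assumptions, not files
# shipped with this module):
#
#   dan_model = DAN(device="cpu")
#   dan_model.load(Path("model.pt"), Path("parameters.yml"), Path("charset.pkl"))
#   visu_image, input_image = dan_model.preprocess("page.jpg")
#   batch = pad_images([input_image]).to("cpu")
#   prediction = dan_model.predict(batch, [input_image.shape[1:]], confidences=True)
#   print(prediction["text"][0])
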
def process_batch(
image_batch: List[Path],
dan_model: DAN,
device: str,
output: Path,
confidence_score: bool,
confidence_score_levels: List[Level],
attention_map: bool,
attention_map_level: Level,
attention_map_scale: float,
word_separators: List[str],
line_separators: List[str],
predict_objects: bool,
max_object_height: int,
tokens: Dict[str, EntityType],
start_token: str,
use_language_model: bool,
) -> None:
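    """
    Run prediction on a batch of images and save one JSON result file
    (and optionally an attention GIF) per image in the output folder.
    """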
input_images, visu_images, input_sizes = [], [], []
logger.info("Loading images...")
for image_path in image_batch:
# Load image and pre-process it
visu_image, input_image = dan_model.preprocess(str(image_path))
input_images.append(input_image)
visu_images.append(visu_image)
input_sizes.append(input_image.shape[1:])
    # Pad the images to a common size and stack them into tensors of size (batch_size, channel, height, width)
input_tensor = pad_images(input_images).to(device)
visu_tensor = pad_images(visu_images).to(device)
logger.info("Images preprocessed!")
# Parse delimiters to regex
word_separators = parse_delimiters(word_separators)
line_separators = parse_delimiters(line_separators)
# Predict
logger.info("Predicting...")
prediction = dan_model.predict(
input_tensor,
input_sizes,
confidences=confidence_score,
attentions=attention_map,
attention_level=attention_map_level,
extract_objects=predict_objects,
word_separators=word_separators,
line_separators=line_separators,
tokens=tokens,
max_object_height=max_object_height,
start_token=start_token,
use_language_model=use_language_model,
)
logger.info("Prediction parsing...")
for idx, image_path in enumerate(image_batch):
predicted_text = prediction["text"][idx]
result = {"text": predicted_text, "confidences": {}, "language_model": {}}
if predicted_text:
# Return LM results
if use_language_model:
result["language_model"] = {
"text": prediction["language_model"]["text"][idx],
"confidence": prediction["language_model"]["confidence"][idx],
}
# Return extracted objects (coordinates, text, confidence)
if predict_objects:
result["objects"] = prediction["objects"][idx]
# Return mean confidence score
if confidence_score:
char_confidences = prediction["confidences"][idx]
result["confidences"]["total"] = np.around(np.mean(char_confidences), 2)
for level in confidence_score_levels:
result["confidences"][level.value] = []
texts, confidences, _ = split_text_and_confidences(
predicted_text,
char_confidences,
level,
word_separators,
line_separators,
tokens,
)
for text, conf in zip(texts, confidences):
result["confidences"][level.value].append(
{"text": text, "confidence": conf}
)
# Save gif with attention map
if attention_map:
attentions = prediction["attentions"][idx]
gif_filename = f"{output}/{image_path.stem}_{attention_map_level}.gif"
logger.info(f"Creating attention GIF in {gif_filename}")
plot_attention(
image=visu_tensor[idx],
text=predicted_text,
weights=attentions,
level=attention_map_level,
scale=attention_map_scale,
word_separators=word_separators,
line_separators=line_separators,
tokens=tokens,
display_polygons=predict_objects,
max_object_height=max_object_height,
outname=gif_filename,
)
result["attention_gif"] = gif_filename
json_filename = Path(output, f"{image_path.stem}.json")
logger.info(f"Saving JSON prediction in {json_filename}")
json_filename.write_text(json.dumps(result, indent=2))
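
# The JSON written for each image looks like the sketch below (illustrative
# values; which fields are present depends on the flags passed to `process_batch`):
#
#   {
#     "text": "...",
#     "confidences": {"total": 0.99, "word": [{"text": "...", "confidence": 0.98}]},
#     "language_model": {"text": "...", "confidence": 0.97},
#     "objects": [...],
#     "attention_gif": "output/page_line.gif"
#   }
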
def run(
image: Optional[Path],
image_dir: Optional[Path],
model: Path,
parameters: Path,
charset: Path,
output: Path,
confidence_score: bool,
confidence_score_levels: List[Level],
attention_map: bool,
attention_map_level: Level,
attention_map_scale: float,
word_separators: List[str],
line_separators: List[str],
temperature: float,
predict_objects: bool,
max_object_height: int,
image_extension: str,
gpu_device: int,
batch_size: int,
tokens: Dict[str, EntityType],
start_token: str,
use_language_model: bool,
) -> None:
"""
Predict a single image save the output
:param image: Path to the image to predict.
:param image_dir: Path to the folder where the images to predict are stored.
:param model: Path to the model to use for prediction.
:param parameters: Path to the YAML parameters file.
:param charset: Path to the charset.
:param output: Path to the output folder where the results will be saved.
:param confidence_score: Whether to compute confidence score.
:param attention_map: Whether to plot the attention map.
:param attention_map_level: Level of objects to extract.
:param attention_map_scale: Scaling factor for the attention map.
:param word_separators: List of word separators.
:param line_separators: List of line separators.
:param predict_objects: Whether to extract objects.
:param max_object_height: Maximum height of predicted objects.
:param gpu_device: Use a specific GPU if available.
:param batch_size: Size of the batches for prediction.
:param tokens: NER tokens used.
:param start_token: Use a specific starting token at the beginning of the prediction. Useful when making predictions on different single pages.
:param use_language_model: Whether to use an explicit language model to rescore text hypotheses.
"""
    # Create the output directory if needed
    output.mkdir(parents=True, exist_ok=True)
# Load model
cuda_device = f":{gpu_device}" if gpu_device is not None else ""
device = f"cuda{cuda_device}" if torch.cuda.is_available() else "cpu"
dan_model = DAN(device, temperature)
dan_model.load(
model, parameters, charset, mode="eval", use_language_model=use_language_model
)
    # Disable LM rescoring when the model was loaded without a usable language model (e.g. weight <= 0)
    use_language_model = dan_model.lm_decoder is not None
images = image_dir.rglob(f"*{image_extension}") if not image else [image]
for image_batch in list_to_batches(images, n=batch_size):
process_batch(
image_batch,
dan_model,
device,
output,
confidence_score,
confidence_score_levels,
attention_map,
attention_map_level,
attention_map_scale,
word_separators,
line_separators,
predict_objects,
max_object_height,
tokens,
start_token,
use_language_model,
)
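
# A hedged example of calling `run` directly (hypothetical paths and
# illustrative values; the DAN CLI normally builds these arguments from
# command-line flags):
#
#   run(
#       image=Path("page.jpg"),
#       image_dir=None,
#       model=Path("model.pt"),
#       parameters=Path("parameters.yml"),
#       charset=Path("charset.pkl"),
#       output=Path("predictions"),
#       confidence_score=True,
#       confidence_score_levels=[Level.Word],
#       attention_map=False,
#       attention_map_level=Level.Line,
#       attention_map_scale=0.5,
#       word_separators=[" ", "\n"],
#       line_separators=["\n"],
#       temperature=1.0,
#       predict_objects=False,
#       max_object_height=50,
#       image_extension=".jpg",
#       gpu_device=None,
#       batch_size=1,
#       tokens={},
#       start_token=None,
#       use_language_model=False,
#   )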