Skip to content
Snippets Groups Projects
prediction.py 10.73 KiB
# -*- coding: utf-8 -*-

import os
import pickle
import re

import cv2
import numpy as np
import torch
import yaml

from dan import logger
from dan.datasets.extract.utils import save_json
from dan.decoder import GlobalHTADecoder
from dan.models import FCN_Encoder
from dan.ocr.utils import LM_ind_to_str
from dan.predict.attention import plot_attention
from dan.utils import read_image, round_floats


class DAN:
    """
    The DAN class is used to apply a DAN model.

    It is constructed with a device; ``load`` then sets the attributes used for
    inference: ``charset``, ``encoder``, ``decoder``, ``mean``, ``std`` and
    ``max_chars``.
    """

    def __init__(self, device):
        """
        Constructor of the DAN class.
        :param device: The device to use (passed to torch ``.to`` / ``map_location``).
        """
        super().__init__()
        self.device = device

    def load(self, model_path, params_path, charset_path, mode="eval"):
        """
        Load a trained model.
        :param model_path: Path to the model checkpoint (read with ``torch.load``).
        :param params_path: Path to the YAML parameters file (its "parameters" key is used).
        :param charset_path: Path to the pickled charset.
        :param mode: The mode to load the model ("train" or "eval").
        :raises ValueError: If ``mode`` is neither "train" nor "eval".
        """
        # Validate the mode before doing any (potentially expensive) I/O.
        if mode not in ("train", "eval"):
            raise ValueError("Unsupported mode")

        with open(params_path, "r") as f:
            parameters = yaml.safe_load(f)["parameters"]
        parameters["decoder"]["device"] = self.device

        # NOTE(security): pickle.load and torch.load can execute arbitrary code
        # embedded in the file; only load charsets/checkpoints from trusted sources.
        with open(charset_path, "rb") as f:
            self.charset = pickle.load(f)

        # Restore the model weights.
        checkpoint = torch.load(model_path, map_location=self.device)

        encoder = FCN_Encoder(parameters["encoder"]).to(self.device)
        encoder.load_state_dict(checkpoint["encoder_state_dict"], strict=True)

        decoder = GlobalHTADecoder(parameters["decoder"]).to(self.device)
        decoder.load_state_dict(checkpoint["decoder_state_dict"], strict=True)
        logger.debug(f"Loaded model {model_path}")

        if mode == "train":
            encoder.train()
            decoder.train()
        else:
            encoder.eval()
            decoder.eval()

        self.encoder = encoder
        self.decoder = decoder
        self.mean, self.std = parameters["mean"], parameters["std"]
        self.max_chars = parameters["max_char_prediction"]

    def preprocess(self, input_image):
        """
        Normalize an input image with the model mean and std.
        Images with fewer than 3 dimensions (grayscale) are first converted to RGB.
        :param input_image: The input image to preprocess (numpy array).
        :return: The normalized image, ``(input_image - mean) / std``.
        """
        assert isinstance(
            input_image, np.ndarray
        ), "Input image must be an np.array in RGB"
        if input_image.ndim < 3:
            input_image = cv2.cvtColor(input_image, cv2.COLOR_GRAY2RGB)

        return (input_image - self.mean) / self.std

    def predict(
        self,
        input_tensor,
        input_sizes,
        confidences=False,
        attentions=False,
        start_token=None,
    ):
        """
        Run prediction on a batch of images.
        :param input_tensor: A batch of images to predict (batch is dimension 0;
            presumably (batch, channel, height, width)).
        :param input_sizes: The original images sizes.
        :param confidences: Return the characters probabilities.
        :param attentions: Return characters attention weights.
        :param start_token: Optional charset character used as the start token;
            when falsy, token index ``len(charset) + 1`` is used.
        :return: A dict with "text" (list of decoded strings, one per image) and,
            when requested, "confidences" (per-image lists of the highest softmax
            probability of each predicted character) and "attentions" (numpy array
            of the attention weights of every decoding step).
        """
        input_tensor = input_tensor.to(self.device)

        start_token = (
            self.charset.index(start_token) if start_token else len(self.charset) + 1
        )
        end_token = len(self.charset)

        # Run the prediction.
        with torch.no_grad():
            b = input_tensor.size(0)
            reached_end = torch.zeros((b,), dtype=torch.bool, device=self.device)
            prediction_len = torch.zeros((b,), dtype=torch.int, device=self.device)
            predicted_tokens = (
                torch.ones((b, 1), dtype=torch.long, device=self.device) * start_token
            )
            predicted_tokens_len = torch.ones((b,), dtype=torch.int, device=self.device)

            confidence_scores = []
            attention_maps = []
            cache = None
            hidden_predict = None

            features = self.encoder(input_tensor.float())
            features_size = features.size()
            pos_features = self.decoder.features_updater.get_pos_features(features)
            # Flatten the last two (spatial) dims and move them first.
            features = torch.flatten(pos_features, start_dim=2, end_dim=3).permute(
                2, 0, 1
            )
            enhanced_features = torch.flatten(
                pos_features, start_dim=2, end_dim=3
            ).permute(2, 0, 1)

            for i in range(self.max_chars):
                _, pred, hidden_predict, cache, weights = self.decoder(
                    features,
                    enhanced_features,
                    predicted_tokens,
                    input_sizes,
                    predicted_tokens_len,
                    features_size,
                    start=0,
                    hidden_predict=hidden_predict,
                    cache=cache,
                    num_pred=1,
                )
                attention_maps.append(weights)
                # Step confidence = highest softmax probability along dim 1.
                confidence_scores.append(
                    torch.max(torch.softmax(pred, dim=1), dim=1).values
                )
                next_tokens = torch.argmax(pred[:, :, -1], dim=1, keepdim=True)
                predicted_tokens = torch.cat([predicted_tokens, next_tokens], dim=1)
                reached_end = torch.logical_or(
                    reached_end, torch.eq(predicted_tokens[:, -1], end_token)
                )
                predicted_tokens_len += 1

                # Running sequences now hold i + 1 tokens; finished ones keep the
                # length recorded before their end token was emitted.
                prediction_len[~reached_end] = i + 1

                if torch.all(reached_end):
                    break

            # Concatenate tensors for each token
            confidence_scores = (
                torch.cat(confidence_scores, dim=1).cpu().detach().numpy()
            )
            attention_maps = torch.cat(attention_maps, dim=1).cpu().detach().numpy()

            # Remove the start token; the end token is excluded via prediction_len.
            predicted_tokens = predicted_tokens[:, 1:]
            # NOTE(review): sequences that never produced the end token are cut to
            # max_chars - 1 tokens (dropping their last prediction). Kept as-is to
            # preserve existing outputs — confirm whether this is intended.
            prediction_len[~reached_end] = self.max_chars - 1
            predicted_tokens = [
                predicted_tokens[k, : prediction_len[k]] for k in range(b)
            ]
            confidence_scores = [
                confidence_scores[k, : prediction_len[k]].tolist() for k in range(b)
            ]

            # Transform tokens to characters
            predicted_text = [
                LM_ind_to_str(self.charset, t, oov_symbol="") for t in predicted_tokens
            ]

            logger.info("Images processed")

        out = {"text": predicted_text}
        if confidences:
            out["confidences"] = confidence_scores
        if attentions:
            out["attentions"] = attention_maps
        return out


def parse_delimiters(delimiters):
    """
    Build a compiled regex matching any one of the given delimiters.

    The items are joined into a ``|`` alternation as-is (they are not escaped),
    so each one is interpreted as a regular-expression fragment.

    :param delimiters: Iterable of delimiter strings.
    :return: The compiled regular expression.
    """
    alternation = "|".join(delimiters)
    return re.compile(alternation)


def compute_prob_by_separator(characters, probabilities, separator):
    """
    Split text on a separator regex and return the average confidence of each piece.

    :param characters: The text, with one probability per character.
    :param probabilities: Sequence of per-character probabilities aligned with ``characters``.
    :param separator: Compiled regex matching separators. Use parse_delimiters(["\\n", " "])
        for word confidences and parse_delimiters(["\\n"]) for line confidences.
    :return: A list with the mean probability of every non-empty piece of text found
        between separator matches, in order of appearance.
    """
    # Walk the separator matches directly instead of pasting separator.pattern into
    # a "[^...]" character class: inside a class the "|" of an alternation becomes a
    # literal (so "|" in the text would split pieces) and multi-character or regex
    # delimiters are not honored.
    pieces = []
    start = 0
    for match in separator.finditer(characters):
        if match.start() > start:
            pieces.append((start, match.start()))
        start = match.end()
    if start < len(characters):
        pieces.append((start, len(characters)))

    # Iterate over text pieces and compute mean confidence
    return [np.mean(probabilities[begin:end]) for begin, end in pieces]


def run(
    image,
    model,
    parameters,
    charset,
    output,
    scale,
    confidence_score,
    confidence_score_levels,
    attention_map,
    attention_map_level,
    attention_map_scale,
    word_separators,
    line_separators,
):
    """
    Predict the text of one image with a DAN model and save the result as JSON.

    The JSON file is written to ``{output}/{image.stem}.json`` and contains the
    predicted text, plus confidence scores and the attention GIF path when requested.

    :param image: Path of the input image (its ``stem`` names the output files).
    :param model: Path to the model checkpoint.
    :param parameters: Path to the YAML parameters file.
    :param charset: Path to the pickled charset.
    :param output: Output directory, created (with missing parents) if needed.
    :param scale: Scale forwarded to ``read_image``.
    :param confidence_score: Whether to compute confidence scores.
    :param confidence_score_levels: Levels to report among "word", "line" and "char".
    :param attention_map: Whether to save an attention GIF.
    :param attention_map_level: Level forwarded to ``plot_attention`` and used in the GIF name.
    :param attention_map_scale: Scale forwarded to ``plot_attention``.
    :param word_separators: Word delimiters, joined into a regex by ``parse_delimiters``.
    :param line_separators: Line delimiters, joined into a regex by ``parse_delimiters``.
    """
    # Create output directory if necessary. makedirs also creates missing parents
    # and, with exist_ok, avoids the exists()/mkdir() check-then-create race.
    os.makedirs(output, exist_ok=True)

    # Load model
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dan_model = DAN(device)
    dan_model.load(model, parameters, charset, mode="eval")

    # Load image and pre-process it
    im = read_image(image, scale=scale)
    logger.info("Image loaded.")
    im_p = dan_model.preprocess(im)
    logger.debug("Image pre-processed.")

    # Convert to tensor of size (batch_size, channel, height, width) with batch_size=1
    input_tensor = torch.tensor(im_p).permute(2, 0, 1).unsqueeze(0).to(device)
    input_sizes = [im.shape[:2]]

    # Predict
    prediction = dan_model.predict(
        input_tensor,
        input_sizes,
        confidences=confidence_score,
        attentions=attention_map,
    )
    text = prediction["text"][0]
    result = {"text": text}

    # Parse delimiters to regex
    word_separators = parse_delimiters(word_separators)
    line_separators = parse_delimiters(line_separators)

    # Average character-based confidence scores
    if confidence_score:
        char_confidences = prediction["confidences"][0]
        # NOTE(review): np.mean of an empty prediction yields nan — confirm how
        # save_json should handle an empty text.
        confidences = {"total": np.around(np.mean(char_confidences), 2)}
        if "word" in confidence_score_levels:
            confidences["word"] = round_floats(
                compute_prob_by_separator(text, char_confidences, word_separators)
            )
        if "line" in confidence_score_levels:
            confidences["line"] = round_floats(
                compute_prob_by_separator(text, char_confidences, line_separators)
            )
        if "char" in confidence_score_levels:
            confidences["char"] = round_floats(char_confidences)
        result["confidences"] = confidences

    # Save gif with attention map
    if attention_map:
        gif_filename = f"{output}/{image.stem}_{attention_map_level}.gif"
        logger.info(f"Creating attention GIF in {gif_filename}")
        plot_attention(
            image=im,
            text=text,
            weights=prediction["attentions"][0],
            level=attention_map_level,
            scale=attention_map_scale,
            word_separators=word_separators,
            line_separators=line_separators,
            outname=gif_filename,
        )
        result["attention_gif"] = gif_filename

    json_filename = f"{output}/{image.stem}.json"
    logger.info(f"Saving JSON prediction in {json_filename}")
    save_json(json_filename, result)