Skip to content
Snippets Groups Projects

Add predicted objects to predict command

Merged Thibault Lavigne requested to merge 36-add-predicted-objects-to-predict-command into main
All threads resolved!
1 file
+ 3
1
Compare changes
  • Side-by-side
  • Inline
+ 110
50
@@ -2,7 +2,7 @@
import os
import pickle
import re
from pathlib import Path
import cv2
import numpy as np
@@ -14,8 +14,13 @@ from dan.datasets.extract.utils import save_json
from dan.decoder import GlobalHTADecoder
from dan.models import FCN_Encoder
from dan.ocr.utils import LM_ind_to_str
from dan.predict.attention import plot_attention
from dan.utils import read_image, round_floats
from dan.predict.attention import (
get_predicted_polygons_with_confidence,
parse_delimiters,
plot_attention,
split_text_and_confidences,
)
from dan.utils import read_image
class DAN:
@@ -92,7 +97,13 @@ class DAN:
input_sizes,
confidences=False,
attentions=False,
attention_level=False,
extract_objects=False,
word_separators=["\n", " "],
line_separators=["\n"],
start_token=None,
threshold_method="otsu",
threshold_value=0,
):
"""
Run prediction on an input image.
@@ -100,6 +111,10 @@ class DAN:
:param input_sizes: The original images sizes.
:param confidences: Return the characters probabilities.
:param attentions: Return characters attention weights.
:param attention_level: Level of text pieces (must be in [char, word, line])
:param extract_objects: Whether to extract polygons' coordinates.
:param threshold_method: Thresholding method. Should be in ["otsu", "simple"].
:param threshold_value: Thresholding value to use for the "simple" thresholding method.
"""
input_tensor = input_tensor.to(self.device)
@@ -110,13 +125,20 @@ class DAN:
# Run the prediction.
with torch.no_grad():
b = input_tensor.size(0)
reached_end = torch.zeros((b,), dtype=torch.bool, device=self.device)
prediction_len = torch.zeros((b,), dtype=torch.int, device=self.device)
batch_size = input_tensor.size(0)
reached_end = torch.zeros(
(batch_size,), dtype=torch.bool, device=self.device
)
prediction_len = torch.zeros(
(batch_size,), dtype=torch.int, device=self.device
)
predicted_tokens = (
torch.ones((b, 1), dtype=torch.long, device=self.device) * start_token
torch.ones((batch_size, 1), dtype=torch.long, device=self.device)
* start_token
)
predicted_tokens_len = torch.ones(
(batch_size,), dtype=torch.int, device=self.device
)
predicted_tokens_len = torch.ones((b,), dtype=torch.int, device=self.device)
whole_output = list()
confidence_scores = list()
@@ -185,10 +207,11 @@ class DAN:
predicted_tokens = predicted_tokens[:, 1:]
prediction_len[torch.eq(reached_end, False)] = self.max_chars - 1
predicted_tokens = [
predicted_tokens[i, : prediction_len[i]] for i in range(b)
predicted_tokens[i, : prediction_len[i]] for i in range(batch_size)
]
confidence_scores = [
confidence_scores[i, : prediction_len[i]].tolist() for i in range(b)
confidence_scores[i, : prediction_len[i]].tolist()
for i in range(batch_size)
]
# Transform tokens to characters
@@ -198,34 +221,32 @@ class DAN:
logger.info("Images processed")
out = {"text": predicted_text}
out = {}
out["text"] = predicted_text
if confidences:
out["confidences"] = confidence_scores
if attentions:
out["attentions"] = attention_maps
if extract_objects:
out["objects"] = [
get_predicted_polygons_with_confidence(
predicted_text[i],
attention_maps[i],
confidence_scores[i],
attention_level,
input_sizes[i][0],
input_sizes[i][1],
threshold_method=threshold_method,
threshold_value=threshold_value,
word_separators=word_separators,
line_separators=line_separators,
)
for i in range(batch_size)
]
return out
def parse_delimiters(delimiters):
    """
    Build a single compiled regex that matches any of the given delimiters.

    The delimiters are joined with ``|`` as-is (no escaping), so each one is
    interpreted as a regular expression fragment.

    :param delimiters: Iterable of delimiter strings, e.g. ``["\\n", " "]``.
    :return: Compiled pattern matching any one of the delimiters.
    """
    alternation = r"|".join(delimiters)
    return re.compile(alternation)
def compute_prob_by_separator(characters, probabilities, separator):
    """
    Split text and confidences on separators and return the mean confidence of each text piece.

    :param characters: Predicted text, as a string.
    :param probabilities: Sequence of per-character probabilities, aligned with ``characters``.
    :param separator: Compiled regex for separators. Use parse_delimiters(["\\n", " "]) for word
        confidences and parse_delimiters(["\\n"]) for line confidences.
    :return: List of mean confidence scores, one per non-empty text piece, in reading order.
    """
    # Collect the (start, end) spans of text lying between separator matches.
    # The compiled separator is used directly rather than pasted into a
    # "[^...]" character class: inside a class the "|" of an alternation
    # becomes a literal member, so pipes in the text would be treated as
    # separators and multi-character delimiters would be split apart.
    pieces = []
    piece_start = 0
    for match in separator.finditer(characters):
        if match.start() > piece_start:
            pieces.append((piece_start, match.start()))
        piece_start = max(piece_start, match.end())
    # Trailing piece after the last separator, if any.
    if piece_start < len(characters):
        pieces.append((piece_start, len(characters)))

    # Iterate over text pieces and compute mean confidence
    return [np.mean(probabilities[start:end]) for (start, end) in pieces]
def run(
image,
model,
@@ -240,7 +261,28 @@ def run(
attention_map_scale,
word_separators,
line_separators,
predict_objects,
threshold_method,
threshold_value,
):
"""
Predict a single image and save the output
:param image: Path to the image to predict.
:param model: Path to the model to use for prediction.
:param parameters: Path to the YAML parameters file.
:param charset: Path to the charset.
:param output: Path to the output folder where the results will be saved.
:param scale: Scaling factor to resize the image.
:param confidence_score: Whether to compute confidence score.
:param attention_map: Whether to plot the attention map.
:param attention_map_level: Level of objects to extract.
:param attention_map_scale: Scaling factor for the attention map.
:param word_separators: List of word separators.
:param line_separators: List of line separators.
:param predict_objects: Whether to extract objects.
:param threshold_method: Thresholding method. Should be in ["otsu", "simple"].
:param threshold_value: Thresholding value to use for the "simple" thresholding method.
"""
# Create output directory if necessary
if not os.path.exists(output):
os.mkdir(output)
@@ -261,41 +303,56 @@ def run(
input_tensor = input_tensor.to(device)
input_sizes = [im.shape[:2]]
# Parse delimiters to regex
word_separators = parse_delimiters(word_separators)
line_separators = parse_delimiters(line_separators)
# Predict
prediction = dan_model.predict(
input_tensor,
input_sizes,
confidences=confidence_score,
attentions=attention_map,
attention_level=attention_map_level,
extract_objects=predict_objects,
word_separators=word_separators,
line_separators=line_separators,
threshold_method=threshold_method,
threshold_value=threshold_value,
)
text = prediction["text"][0]
result = {"text": text}
# Parse delimiters to regex
word_separators = parse_delimiters(word_separators)
line_separators = parse_delimiters(line_separators)
result = {}
result["text"] = prediction["text"][0]
# Return extracted objects (coordinates, text, confidence)
if predict_objects:
result["objects"] = prediction["objects"][0]
# Average character-based confidence scores
# Return mean confidence score
if confidence_score:
result["confidences"] = {}
char_confidences = prediction["confidences"][0]
result["confidences"] = {"total": np.around(np.mean(char_confidences), 2)}
if "word" in confidence_score_levels:
word_probs = compute_prob_by_separator(
text, char_confidences, word_separators
)
result["confidences"].update({"word": round_floats(word_probs)})
if "line" in confidence_score_levels:
line_probs = compute_prob_by_separator(
text, char_confidences, line_separators
result["confidences"]["total"] = np.around(np.mean(char_confidences), 2)
for level in confidence_score_levels:
result["confidences"][level] = []
texts, confidences, _ = split_text_and_confidences(
prediction["text"][0],
char_confidences,
level,
word_separators,
line_separators,
)
result["confidences"].update({"line": round_floats(line_probs)})
if "char" in confidence_score_levels:
result["confidences"].update({"char": round_floats(char_confidences)})
for text, conf in zip(texts, confidences):
result["confidences"][level].append({"text": text, "confidence": conf})
# Save gif with attention map
if attention_map:
gif_filename = f"{output}/{image.stem}_{attention_map_level}.gif"
logger.info(f"Creating attention GIF in {gif_filename}")
# this returns polygons but unused for now.
plot_attention(
image=im,
text=prediction["text"][0],
@@ -304,10 +361,13 @@ def run(
scale=attention_map_scale,
word_separators=word_separators,
line_separators=line_separators,
display_polygons=predict_objects,
threshold_method=threshold_method,
threshold_value=threshold_value,
outname=gif_filename,
)
result["attention_gif"] = gif_filename
json_filename = f"{output}/{image.stem}.json"
logger.info(f"Saving JSON prediction in {json_filename}")
save_json(json_filename, result)
save_json(Path(json_filename), result)
Loading