Commit 69f278a9 authored by Tristan Faine, committed by Yoann Schneider

Add predicted objects to predict command

parent 614fa206
Merge request !76: Add predicted objects to predict command
@@ -106,4 +106,22 @@ def add_predict_parser(subcommands) -> None:
help="String separators used to split text into lines.",
required=False,
)
parser.add_argument(
"--predict-objects",
action="store_true",
help="Whether to output objects when plotting attention maps.",
required=False,
)
parser.add_argument(
"--threshold-method",
help="Thresholding method.",
choices=["otsu", "simple"],
default="otsu",
)
parser.add_argument(
"--threshold-value",
help="Thresholding value.",
type=int,
default=0,
)
parser.set_defaults(func=run)
@@ -6,18 +6,43 @@ import numpy as np
from PIL import Image
from dan import logger
from dan.utils import round_floats
def parse_delimiters(delimiters):
return re.compile(r"|".join(delimiters))
def compute_prob_by_separator(characters, probabilities, separator):
"""
Split text and confidences using separators and return the text pieces with their average confidence scores.
:param characters: The predicted text.
:param probabilities: The predicted confidence scores, one per character.
:param separator: Regex matching the separators. Use parse_delimiters(["\n", " "]) for word confidences and parse_delimiters(["\n"]) for line confidences.
Returns a list of text pieces and a list of confidence scores.
"""
# match anything except separators, get start and end index
pattern = re.compile(f"[^{separator.pattern}]+")
matches = [(m.start(), m.end()) for m in re.finditer(pattern, characters)]
# Iterate over text pieces and compute mean confidence
probs = [np.mean(probabilities[start:end]) for (start, end) in matches]
texts = [characters[start:end] for (start, end) in matches]
return texts, probs
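Word- and line-level confidences are just the mean of the character probabilities between separators. A minimal usage sketch, assuming the helpers above are importable from `dan.predict.attention` (the text and scores are invented sample values):

```python
# Sketch: word-level confidences from per-character probabilities.
# Sample text and probabilities are invented for illustration.
from dan.predict.attention import compute_prob_by_separator, parse_delimiters

word_separators = parse_delimiters(["\n", " "])
texts, probs = compute_prob_by_separator(
    "le 12", [0.9, 0.8, 1.0, 0.7, 0.6], word_separators
)
# texts -> ["le", "12"]; probs -> [0.85, 0.65] (mean over each word's characters)
```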
def split_text(text: str, level: str, word_separators, line_separators):
""" """
Split text into a list of characters, word, or lines. Split text into a list of characters, word, or lines.
:param text: Text prediction from DAN :param text: Text prediction from DAN
:param level: Level to visualize (char, word, line) :param level: Level to visualize from [char, word, line]
:param word_separators: List of word separators
:param line_separators: List of line separators
""" """
# split into characters
if level == "char":
text_split = list(text)
offset = 0
# split into words
elif level == "word":
text_split = re.split(word_separators, text)
@@ -31,13 +56,89 @@ def split_text(text, level, word_separators, line_separators):
return text_split, offset
def split_text_and_confidences(
text, confidences, level, word_separators, line_separators
):
"""
Split text into a list of characters, words or lines with corresponding confidence scores
:param text: Text prediction from DAN
:param confidences: Character confidences
:param level: Level to visualize from [char, word, line]
:param word_separators: List of word separators
:param line_separators: List of line separators
"""
if level == "char":
texts = list(text)
probs = confidences
offset = 0
elif level == "word":
texts, probs = compute_prob_by_separator(text, confidences, word_separators)
offset = 1
elif level == "line":
texts, probs = compute_prob_by_separator(text, confidences, line_separators)
offset = 1
else:
logger.error("Level should be either 'char', 'word', or 'line'")
return [], [], 0
return texts, round_floats(probs), offset
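The same split drives the per-level confidence reporting in the `predict` command. A small sketch at line level (all sample values invented):

```python
# Sketch: line-level split with per-line confidences (sample values invented).
from dan.predict.attention import parse_delimiters, split_text_and_confidences

texts, confs, offset = split_text_and_confidences(
    "Oslo\n39",
    [0.9, 0.8, 1.0, 0.7, 0.5, 0.6, 0.4],
    "line",
    parse_delimiters(["\n", " "]),
    parse_delimiters(["\n"]),
)
# texts -> ["Oslo", "39"]; confs -> [0.85, 0.5]; offset -> 1 (the skipped separator)
```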
def get_predicted_polygons_with_confidence(
text,
weights,
confidences,
level,
height,
width,
threshold_method="otsu",
threshold_value=0,
word_separators=["\n", " "],
line_separators=["\n"],
):
"""
Returns the polygons of each object of the current prediction
:param text: Text predicted by DAN
:param weights: Attention weights of size (n_char, feature_height, feature_width)
:param confidences: Character confidences
:param level: Level to display (must be in [char, word, line])
:param height: Original image height
:param width: Original image width
:param threshold_method: Thresholding method. Should be in ["otsu", "simple"]
:param threshold_value: Thresholding value for the "simple" method.
:param word_separators: List of word separators
:param line_separators: List of line separators
"""
# Split text into characters, words or lines
text_list, confidence_list, offset = split_text_and_confidences(
text, confidences, level, word_separators, line_separators
)
max_value = weights.sum(0).max()
polygons = []
start_index = 0
for text_piece, confidence in zip(text_list, confidence_list):
polygon, _ = get_polygon(
text_piece,
max_value,
start_index,
weights,
threshold_method=threshold_method,
threshold_value=threshold_value,
size=(width, height),
)
polygon["text"] = text_piece
polygon["text_confidence"] = confidence
start_index += len(text_piece) + offset
polygons.append(polygon)
return polygons
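Given a prediction and its attention weights, object extraction reduces to a single call. A sketch with random weights, where all shapes and values are invented (a real call uses the outputs of `DAN.predict`):

```python
# Sketch: extract line-level polygons from dummy attention weights.
# Shapes and values are invented; real weights come from DAN.predict.
import numpy as np
from dan.predict.attention import get_predicted_polygons_with_confidence

weights = np.random.rand(12, 32, 128).astype(np.float32)  # (n_char, feat_h, feat_w)
polygons = get_predicted_polygons_with_confidence(
    text="Oslo\n39",
    weights=weights,
    confidences=[0.9] * 7,
    level="line",
    height=256,
    width=1024,
)
# One dict per line: {"confidence": ..., "polygon": [...], "text": ..., "text_confidence": ...}
```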
def compute_coverage(text: str, max_value: float, offset: int, attentions, size: tuple):
""" """
Aggregates attention maps for the current text piece (char, word, line) Aggregates attention maps for the current text piece (char, word, line)
:param text: Text piece selected with offset after splitting DAN prediction :param text: Text piece selected with offset after splitting DAN prediction
:param max_value: Maximum "attention intensity" for parts of a text piece, used for normalization :param max_value: Maximum "attention intensity" for parts of a text piece, used for normalization
:param offset: Offset value to get the relevant part of text piece :param offset: Offset value to get the relevant part of text piece
:param attentions: Attention weights of size (n_char, feature_height, feature_width) :param attentions: Attention weights of size (n_char, feature_height, feature_width)
:param size: Target size (width, height) to resize the coverage vector
""" """
_, height, width = attentions.shape
@@ -49,9 +150,130 @@ def compute_coverage(text: str, max_value: float, offset: int, attentions):
# Normalize coverage vector
coverage_vector = (coverage_vector / max_value * 255).astype(np.uint8)
# Resize it
if size:
coverage_vector = cv2.resize(coverage_vector, size)
return coverage_vector
def blend_coverage(coverage_vector, image, mask, scale):
"""
Blends current coverage_vector over original image, used to make an attention map.
:param coverage_vector: Aggregated attention weights of the current text piece, resized to the image size: (image_height, image_width)
:param image: Input image in PIL format
:param mask: Mask of the image (of any color)
:param scale: Scaling factor for the output gif image
"""
height, width = coverage_vector.shape
# Blend coverage vector with original image
blank_array = np.zeros((height, width)).astype(np.uint8)
coverage_vector = Image.fromarray(
np.stack([coverage_vector, blank_array, blank_array], axis=2), "RGB"
)
blend = Image.composite(image, coverage_vector, mask)
# Resize to save time
blend = blend.resize((int(width * scale), int(height * scale)), Image.ANTIALIAS)
return blend
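`blend_coverage` only fills the red channel, so the attention shows up as a red overlay wherever the mask lets it through. A self-contained sketch (all inputs invented):

```python
# Sketch: blend a fake attention blob over a blank image (all inputs invented).
import numpy as np
from PIL import Image
from dan.predict.attention import blend_coverage

coverage = np.zeros((100, 200), dtype=np.uint8)
coverage[30:60, 50:120] = 255  # fake attention blob
image = Image.new("RGB", (200, 100), "white")
mask = Image.new("L", (200, 100), 128)  # uniform 50% blending mask
frame = blend_coverage(coverage, image, mask, scale=0.5)
# frame is a 100x50 PIL image; note Image.ANTIALIAS requires Pillow < 10.
```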
def compute_contour_metrics(coverage_vector, contour):
"""
Compute the maximum attention value inside a contour and its score (max value × area).
:param coverage_vector: Aggregated attention weights of the current text piece, resized to the image size: (image_height, image_width)
:param contour: Contour of the current attention blob
"""
# draw the contour zone
mask = np.zeros(coverage_vector.shape, dtype=np.uint8)
cv2.drawContours(mask, [contour], -1, (255), -1)
max_value = np.where(mask > 0, coverage_vector, 0).max() / 255
area = cv2.contourArea(contour)
return max_value, max_value * area
def polygon_to_bbx(polygon):
"""
Convert a contour into its axis-aligned bounding box, as a list of four corner points.
"""
x, y, w, h = cv2.boundingRect(polygon)
return [[x, y], [x + w, y], [x + w, y + h], [x, y + h]]
def threshold(mask, threshold_method="otsu", threshold_value=0):
"""
Threshold a grayscale mask.
:param mask: a grayscale image (np.array)
:param threshold_method: method to be used for thresholding. Should be in ["otsu", "simple"].
:param threshold_value: the threshold value used for binarization (used for the "simple" method).
"""
min_kernel = 1
max_kernel = mask.shape[1] // 100
if threshold_method == "simple":
return np.where(mask > threshold_value, 255, 0).astype(np.uint8)
elif threshold_method == "otsu":
# Blur and apply Otsu thresholding
blur = cv2.GaussianBlur(mask, (15, 15), 0)
_, bin_mask = cv2.threshold(blur, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
# Apply dilation
kernel_width = cv2.getStructuringElement(
cv2.MORPH_CROSS, (max_kernel, min_kernel)
)
dilated = cv2.dilate(bin_mask, kernel_width, iterations=3)
return np.asarray(dilated, dtype=np.uint8)
else:
raise NotImplementedError(f"Method {threshold_method} is not implemented.")
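Both branches return a binary `uint8` mask. A quick sketch on a synthetic gradient (values invented):

```python
# Sketch: binarize a synthetic grayscale mask with both methods (values invented).
import numpy as np
from dan.predict.attention import threshold

mask = np.tile(np.linspace(0, 255, 200, dtype=np.uint8), (100, 1))
simple = threshold(mask, threshold_method="simple", threshold_value=127)
otsu = threshold(mask, threshold_method="otsu")
# Both results only contain 0 and 255; the "otsu" branch also dilates the mask.
```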
def get_polygon(
text, max_value, offset, weights, threshold_method, threshold_value, size=None
):
"""
Gets polygon associated with element of current text_piece, indexed by offset
:param text: Text piece selected with offset after splitting DAN prediction
:param max_value: Maximum "attention intensity" for parts of a text piece, used for normalization
:param offset: Offset value to get the relevant part of text piece
:param size: Target size (width, height) to resize the coverage vector
:param threshold_method: Binarization method to use (should be in ["simple", "otsu"])
:param threshold_value: Threshold value used for the "simple" binarization method
"""
# Compute coverage vector
coverage_vector = compute_coverage(text, max_value, offset, weights, size=size)
# Generate a binary image for the current channel.
bin_mask = threshold(
coverage_vector,
threshold_method=threshold_method,
threshold_value=threshold_value,
)
# Detect the objects' contours
contours, _ = cv2.findContours(bin_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
if not contours:
return {}, None
# Select best contour
metrics = [compute_contour_metrics(coverage_vector, cnt) for cnt in contours]
confidences, scores = map(list, zip(*metrics))
best_contour = contours[np.argmax(scores)]
confidence = round(confidences[np.argmax(scores)] / max_value, 2)
# Format for JSON
coord = polygon_to_bbx(np.squeeze(best_contour))
polygon = {
"confidence": confidence,
"polygon": coord,
}
simplified_contour = np.expand_dims(np.array(coord, dtype=np.int32), axis=1)
return polygon, simplified_contour
def plot_attention(
image,
text,
@@ -59,8 +281,11 @@ def plot_attention(
level,
scale,
outname,
threshold_method="otsu",
threshold_value=0,
word_separators=["\n", " "], word_separators=["\n", " "],
line_separators=["\n"], line_separators=["\n"],
display_polygons=False,
):
""" """
Create a gif by blending attention maps to the image for each text piece (char, word or line) Create a gif by blending attention maps to the image for each text piece (char, word or line)
...@@ -70,6 +295,9 @@ def plot_attention( ...@@ -70,6 +295,9 @@ def plot_attention(
:param level: Level to display (must be in [char, word, line]) :param level: Level to display (must be in [char, word, line])
:param scale: Scaling factor for the output gif image :param scale: Scaling factor for the output gif image
:param outname: Name of the gif image :param outname: Name of the gif image
:param threshold_method: Thresholding method. Should be in ["otsu", "simple"]
:param threshold_value: Threshold value used for the "simple" thresholding method
:param word_separators: List of word separators
:param line_separators: List of line separators
:param display_polygons: Whether to plot extracted polygons
""" """
height, width, _ = image.shape height, width, _ = image.shape
@@ -84,27 +312,35 @@ def plot_attention(
# Iterate on characters, words or lines
tot_len = 0
max_value = weights.sum(0).max()
for text_piece in text_list:
# Accumulate weights for the current word/line and resize to original image size
coverage_vector = compute_coverage(
text_piece, max_value, tot_len, weights, (width, height)
)
# Get polygons if the flag is set
if display_polygons:
# draw the contour
_, contour = get_polygon(
text_piece,
max_value,
tot_len,
weights,
threshold_method=threshold_method,
threshold_value=threshold_value,
size=(width, height),
)
if contour is not None:
cv2.drawContours(coverage_vector, [contour], 0, (255), 5)
# Keep track of text length
tot_len += len(text_piece) + offset
# Blend coverage vector with original image
attention_map.append(blend_coverage(coverage_vector, image, mask, scale))
attention_map[0].save(
outname,
@@ -2,7 +2,7 @@
import os
import pickle
from pathlib import Path
import cv2
import numpy as np
@@ -14,8 +14,13 @@ from dan.datasets.extract.utils import save_json
from dan.decoder import GlobalHTADecoder
from dan.models import FCN_Encoder
from dan.ocr.utils import LM_ind_to_str
from dan.predict.attention import (
get_predicted_polygons_with_confidence,
parse_delimiters,
plot_attention,
split_text_and_confidences,
)
from dan.utils import read_image
class DAN:
@@ -92,7 +97,13 @@ class DAN:
input_sizes,
confidences=False,
attentions=False,
attention_level=False,
extract_objects=False,
word_separators=["\n", " "],
line_separators=["\n"],
start_token=None,
threshold_method="otsu",
threshold_value=0,
):
""" """
Run prediction on an input image. Run prediction on an input image.
@@ -100,6 +111,10 @@
:param input_sizes: The original image sizes.
:param confidences: Whether to return the character probabilities.
:param attentions: Whether to return the character attention weights.
:param attention_level: Level of text pieces (must be in [char, word, line])
:param extract_objects: Whether to extract polygons' coordinates.
:param threshold_method: Thresholding method. Should be in ["otsu", "simple"].
:param threshold_value: Thresholding value to use for the "simple" thresholding method.
""" """
input_tensor = input_tensor.to(self.device) input_tensor = input_tensor.to(self.device)
@@ -110,13 +125,20 @@
# Run the prediction.
with torch.no_grad():
batch_size = input_tensor.size(0)
reached_end = torch.zeros(
(batch_size,), dtype=torch.bool, device=self.device
)
prediction_len = torch.zeros(
(batch_size,), dtype=torch.int, device=self.device
)
predicted_tokens = (
torch.ones((batch_size, 1), dtype=torch.long, device=self.device)
* start_token
)
predicted_tokens_len = torch.ones(
(batch_size,), dtype=torch.int, device=self.device
)
whole_output = list()
confidence_scores = list()
@@ -185,10 +207,11 @@ class DAN:
predicted_tokens = predicted_tokens[:, 1:]
prediction_len[torch.eq(reached_end, False)] = self.max_chars - 1
predicted_tokens = [
predicted_tokens[i, : prediction_len[i]] for i in range(batch_size)
]
confidence_scores = [
confidence_scores[i, : prediction_len[i]].tolist()
for i in range(batch_size)
]
# Transform tokens to characters
@@ -198,34 +221,32 @@
logger.info("Images processed")
out = {}
out["text"] = predicted_text
if confidences:
out["confidences"] = confidence_scores
if attentions:
out["attentions"] = attention_maps
if extract_objects:
out["objects"] = [
get_predicted_polygons_with_confidence(
predicted_text[i],
attention_maps[i],
confidence_scores[i],
attention_level,
input_sizes[i][0],
input_sizes[i][1],
threshold_method=threshold_method,
threshold_value=threshold_value,
word_separators=word_separators,
line_separators=line_separators,
)
for i in range(batch_size)
]
return out
def run(
image,
model,
@@ -240,7 +261,28 @@ def run(
attention_map_scale,
word_separators,
line_separators,
predict_objects,
threshold_method,
threshold_value,
):
"""
Predict a single image and save the output.
:param image: Path to the image to predict.
:param model: Path to the model to use for prediction.
:param parameters: Path to the YAML parameters file.
:param charset: Path to the charset.
:param output: Path to the output folder where the results will be saved.
:param scale: Scaling factor to resize the image.
:param confidence_score: Whether to compute confidence score.
:param attention_map: Whether to plot the attention map.
:param attention_map_level: Level of text pieces in the attention maps (must be in [char, word, line]).
:param attention_map_scale: Scaling factor for the attention map.
:param word_separators: List of word separators.
:param line_separators: List of line separators.
:param predict_objects: Whether to extract objects.
:param threshold_method: Thresholding method. Should be in ["otsu", "simple"].
:param threshold_value: Thresholding value to use for the "simple" thresholding method.
"""
# Create output directory if necessary
if not os.path.exists(output):
os.mkdir(output)
@@ -261,41 +303,56 @@ def run(
input_tensor = input_tensor.to(device)
input_sizes = [im.shape[:2]]
# Parse delimiters to regex
word_separators = parse_delimiters(word_separators)
line_separators = parse_delimiters(line_separators)
# Predict
prediction = dan_model.predict(
input_tensor,
input_sizes,
confidences=confidence_score,
attentions=attention_map,
attention_level=attention_map_level,
extract_objects=predict_objects,
word_separators=word_separators,
line_separators=line_separators,
threshold_method=threshold_method,
threshold_value=threshold_value,
)
result = {}
result["text"] = prediction["text"][0]
# Return extracted objects (coordinates, text, confidence)
if predict_objects:
result["objects"] = prediction["objects"][0]
# Return mean confidence score
if confidence_score:
result["confidences"] = {}
char_confidences = prediction["confidences"][0]
result["confidences"]["total"] = np.around(np.mean(char_confidences), 2)
for level in confidence_score_levels:
result["confidences"][level] = []
texts, confidences, _ = split_text_and_confidences(
prediction["text"][0],
char_confidences,
level,
word_separators,
line_separators,
)
for text, conf in zip(texts, confidences):
result["confidences"][level].append({"text": text, "confidence": conf})
# Save gif with attention map
if attention_map:
gif_filename = f"{output}/{image.stem}_{attention_map_level}.gif"
logger.info(f"Creating attention GIF in {gif_filename}")
# Polygons are computed here but only drawn on the attention maps, not returned.
plot_attention(
image=im,
text=prediction["text"][0],
@@ -304,10 +361,13 @@ def run(
scale=attention_map_scale,
word_separators=word_separators,
line_separators=line_separators,
display_polygons=predict_objects,
threshold_method=threshold_method,
threshold_value=threshold_value,
outname=gif_filename,
)
result["attention_gif"] = gif_filename
json_filename = f"{output}/{image.stem}.json"
logger.info(f"Saving JSON prediction in {json_filename}")
save_json(Path(json_filename), result)
@@ -13,12 +13,15 @@ Use the `teklia-dan predict` command to predict a trained DAN model on an image.
| `--output` | Path to the output folder. Results will be saved in this directory. | `Path` | |
| `--scale` | Image scaling factor before feeding it to DAN. | `float` | `1.0` |
| `--confidence-score` | Whether to return confidence scores. | `bool` | `False` |
| `--confidence-score-levels` | Levels at which to return confidence scores. Should be any combination of `["line", "word", "char"]`. | `str` | |
| `--attention-map` | Whether to plot attention maps. | `bool` | `False` |
| `--attention-map-scale` | Image scaling factor before creating the GIF. | `float` | `0.5` |
| `--attention-map-level` | Level to plot the attention maps. Should be in `["line", "word", "char"]`. | `str` | `"line"` |
| `--predict-objects` | Whether to return polygon coordinates. | `bool` | `False` |
| `--word-separators` | List of word separators. | `list` | `[" ", "\n"]` |
| `--line-separators` | List of line separators. | `list` | `["\n"]` |
| `--threshold-method` | Method to use for attention mask thresholding. Should be in `["otsu", "simple"]`. | `str` | `"otsu"` |
| `--threshold-value` | Threshold to use for the `"simple"` thresholding method. | `int` | `0` |
## Examples
@@ -100,3 +103,57 @@ It will create the following JSON file named `dan_humu_page/predict/example.json
}
```
<img src="../../assets/example_word.gif" >
### Predict with line-level attention maps and extract polygons
To run a prediction, plot line-level attention maps, and extract polygons, run this command:
```shell
teklia-dan predict \
--image dan_humu_page/example.jpg \
--model dan_humu_page/model.pt \
--parameters dan_humu_page/parameters.yml \
--charset dan_humu_page/charset.pkl \
--output dan_humu_page/predict/ \
--scale 0.5 \
--attention-map \
--predict-objects \
--threshold-method otsu
```
It will create the following JSON file named `dan_humu_page/predict/example.json` and a GIF showing a line-level attention map with the extracted polygons, named `dan_humu_page/predict/example_line.gif`.
```json
{
"text": "Oslo\n39 \nOresden den 24te Rasser!\nH\u00f8jst\u00e6redesherr Hartvig - assert!\nUllereder fra den f\u00f8rste tide da\njeg havder den tilfredsstillelser at vide den ar-\ndistiske ledelser af Kristiania theater i Deres\nhronder, har jeg g\u00e5t hernede med et stille\nh\u00e5b om fra Dem at modtage et forelag, sig -\nsende tils at lade \"K\u00e6rlighedens \u00abKomedie\u00bb\nopf\u00f8re fore det norske purblikum.\nEt s\u00e5dant forslag er imidlertid, imod\nforventning; ikke fremkommet, og jeg n\u00f8des der-\nfor tils self at grivbe initiativet, hvilket hervede\nsker, idet jeg\nbeder\nbet\nragte stigkket some ved denne\nskrivelse officielde indleveret til theatret. No-\nget exemplar af bogen vedlagger jeg ikke da\ndenne (i 2den udgave) med Lethed kan er -\nholdet deroppe.\nDe bet\u00e6nkeligheder, jeg i sin tid n\u00e6-\nrede mod stykkets opf\u00f8relse, er for l\u00e6nge si -\ndem forsvundne. Af mange begn er jeg kom-\nmen til den overbevisning at almenlreden\naru har f\u00e5tt sine \u00f8gne opladte for den sand -\nMed at dette arbejde i sin indersten id\u00e9 hviler\np\u00e5 et ubedinget meralsk grundlag, og brad\nstykkets hele kunstneriske struktuve ang\u00e5r,",
"objects": [
{
"confidence": 0.68,
"polygon": [
[
264,
118
],
[
410,
118
],
[
410,
185
],
[
264,
185
]
],
"text": "Oslo",
"text_confidence": 0.8
},
...
"attention_gif": "dan_humu_page/predict/example_line.gif"
}
```
<img src="../../assets/example_line_polygon.gif" >