Skip to content
Snippets Groups Projects
Commit 774887b1 authored by Yoann Schneider's avatar Yoann Schneider :tennis: Committed by Mélodie Boillet
Browse files

Batch predictions

parent 9c136e83
No related branches found
No related tags found
1 merge request!212Batch predictions
...@@ -140,4 +140,12 @@ def add_predict_parser(subcommands) -> None: ...@@ -140,4 +140,12 @@ def add_predict_parser(subcommands) -> None:
type=int, type=int,
required=False, required=False,
) )
parser.add_argument(
"--batch-size",
help="Size of prediction batches.",
type=int,
default=1,
required=False,
)
parser.set_defaults(func=run) parser.set_defaults(func=run)
...@@ -4,6 +4,7 @@ import re ...@@ -4,6 +4,7 @@ import re
import cv2 import cv2
import numpy as np import numpy as np
from PIL import Image from PIL import Image
from torchvision.transforms.functional import to_pil_image
from dan import logger from dan import logger
...@@ -179,7 +180,7 @@ def blend_coverage(coverage_vector, image, mask, scale): ...@@ -179,7 +180,7 @@ def blend_coverage(coverage_vector, image, mask, scale):
blend = Image.composite(image, coverage_vector, mask) blend = Image.composite(image, coverage_vector, mask)
# Resize to save time # Resize to save time
blend = blend.resize((int(width * scale), int(height * scale)), Image.ANTIALIAS) blend = blend.resize((int(width * scale), int(height * scale)), Image.LANCZOS)
return blend return blend
...@@ -292,7 +293,7 @@ def plot_attention( ...@@ -292,7 +293,7 @@ def plot_attention(
): ):
""" """
Create a gif by blending attention maps to the image for each text piece (char, word or line) Create a gif by blending attention maps to the image for each text piece (char, word or line)
:param image: Input image in PIL format :param image: Input image as torch.Tensor
:param text: Text predicted by DAN :param text: Text predicted by DAN
:param weights: Attention weights of size (n_char, feature_height, feature_width) :param weights: Attention weights of size (n_char, feature_height, feature_width)
:param level: Level to display (must be in [char, word, line]) :param level: Level to display (must be in [char, word, line])
...@@ -303,12 +304,11 @@ def plot_attention( ...@@ -303,12 +304,11 @@ def plot_attention(
:param display_polygons: Whether to plot extracted polygons :param display_polygons: Whether to plot extracted polygons
""" """
height, width, _ = image.shape image = to_pil_image(image)
attention_map = [] attention_map = []
# Convert to PIL Image and create mask # Convert to PIL Image and create mask
mask = Image.new("L", (width, height), color=(110)) mask = Image.new("L", (image.width, image.height), color=(110))
image = Image.fromarray(image)
# Split text into characters, words or lines # Split text into characters, words or lines
text_list, offset = split_text(text, level, word_separators, line_separators) text_list, offset = split_text(text, level, word_separators, line_separators)
...@@ -320,7 +320,7 @@ def plot_attention( ...@@ -320,7 +320,7 @@ def plot_attention(
for text_piece in text_list: for text_piece in text_list:
# Accumulate weights for the current word/line and resize to original image size # Accumulate weights for the current word/line and resize to original image size
coverage_vector = compute_coverage( coverage_vector = compute_coverage(
text_piece, max_value, tot_len, weights, (width, height) text_piece, max_value, tot_len, weights, (image.width, image.height)
) )
# Get polygons if flag is set: # Get polygons if flag is set:
...@@ -333,7 +333,7 @@ def plot_attention( ...@@ -333,7 +333,7 @@ def plot_attention(
weights, weights,
threshold_method=threshold_method, threshold_method=threshold_method,
threshold_value=threshold_value, threshold_value=threshold_value,
size=(width, height), size=(image.width, image.height),
) )
if contour is not None: if contour is not None:
......
...@@ -20,7 +20,7 @@ from dan.predict.attention import ( ...@@ -20,7 +20,7 @@ from dan.predict.attention import (
split_text_and_confidences, split_text_and_confidences,
) )
from dan.transforms import get_normalization_transforms, get_preprocessing_transforms from dan.transforms import get_normalization_transforms, get_preprocessing_transforms
from dan.utils import ind_to_token, read_image from dan.utils import ind_to_token, list_to_batches, pad_images, read_image
class DAN: class DAN:
...@@ -248,8 +248,8 @@ class DAN: ...@@ -248,8 +248,8 @@ class DAN:
return out return out
def process_image( def process_batch(
image_path, image_batch,
dan_model, dan_model,
device, device,
output, output,
...@@ -264,20 +264,25 @@ def process_image( ...@@ -264,20 +264,25 @@ def process_image(
threshold_method, threshold_method,
threshold_value, threshold_value,
): ):
# Load image and pre-process it input_images, input_sizes = [], []
image = dan_model.preprocess(str(image_path)) logger.info("Loading images...")
logger.info("Image loaded.") for image_path in image_batch:
# Load image and pre-process it
image = dan_model.preprocess(str(image_path))
input_images.append(image)
input_sizes.append(image.shape[1:])
# Convert to tensor of size (batch_size, channel, height, width) with batch_size=1 # Convert to tensor of size (batch_size, channel, height, width) with batch_size=1
input_tensor = image.unsqueeze(0) input_tensor = pad_images(input_images).to(device)
input_tensor = input_tensor.to(device)
input_sizes = [image.shape[1:]] logger.info("Images preprocessed!")
# Parse delimiters to regex # Parse delimiters to regex
word_separators = parse_delimiters(word_separators) word_separators = parse_delimiters(word_separators)
line_separators = parse_delimiters(line_separators) line_separators = parse_delimiters(line_separators)
# Predict # Predict
logger.info("Predicting...")
prediction = dan_model.predict( prediction = dan_model.predict(
input_tensor, input_tensor,
input_sizes, input_sizes,
...@@ -290,70 +295,78 @@ def process_image( ...@@ -290,70 +295,78 @@ def process_image(
threshold_method=threshold_method, threshold_method=threshold_method,
threshold_value=threshold_value, threshold_value=threshold_value,
) )
logger.info("Prediction parsing...")
for idx, image_path in enumerate(image_batch):
predicted_text = prediction["text"][idx]
result = {"text": predicted_text}
# Return extracted objects (coordinates, text, confidence)
if predict_objects:
result["objects"] = prediction["objects"][idx]
# Return mean confidence score
if confidence_score:
result["confidences"] = {}
char_confidences = prediction["confidences"][idx]
# retrieve the index of the token ner
index = [
pos
for pos, char in enumerate(predicted_text)
if char in ["", "", "", ""]
]
result = {} # calculates scores by token
result["text"] = prediction["text"][0]
# Return extracted objects (coordinates, text, confidence)
if predict_objects:
result["objects"] = prediction["objects"][0]
# Return mean confidence score
if confidence_score:
result["confidences"] = {}
char_confidences = prediction["confidences"][0]
text = result["text"]
# retrieve the index of the token ner
index = [pos for pos, char in enumerate(text) if char in ["", "", "", ""]]
# calculates scores by token
result["confidences"]["by ner token"] = [
{
"text": f"{text[current: next_token]}".replace("\n", " "),
"confidence_ner": f"{np.around(np.mean(char_confidences[current : next_token]), 2)}",
}
# We go up to -1 so that the last token matches until the end of the text
for current, next_token in pairwise(index + [-1])
]
result["confidences"]["total"] = np.around(np.mean(char_confidences), 2)
for level in confidence_score_levels:
result["confidences"][level] = []
texts, confidences, _ = split_text_and_confidences(
prediction["text"][0],
char_confidences,
level,
word_separators,
line_separators,
)
for text, conf in zip(texts, confidences): result["confidences"]["by ner token"] = [
result["confidences"][level].append({"text": text, "confidence": conf}) {
"text": f"{predicted_text[current: next_token]}".replace("\n", " "),
# Save gif with attention map "confidence_ner": f"{np.around(np.mean(char_confidences[current : next_token]), 2)}",
if attention_map: }
gif_filename = f"{output}/{image_path.stem}_{attention_map_level}.gif" # We go up to -1 so that the last token matches until the end of the text
logger.info(f"Creating attention GIF in {gif_filename}") for current, next_token in pairwise(index + [-1])
# this returns polygons but unused for now. ]
plot_attention( result["confidences"]["total"] = np.around(np.mean(char_confidences), 2)
image=image,
text=prediction["text"][0], for level in confidence_score_levels:
weights=prediction["attentions"][0], result["confidences"][level] = []
level=attention_map_level, texts, confidences, _ = split_text_and_confidences(
scale=attention_map_scale, predicted_text,
word_separators=word_separators, char_confidences,
line_separators=line_separators, level,
display_polygons=predict_objects, word_separators,
threshold_method=threshold_method, line_separators,
threshold_value=threshold_value, )
outname=gif_filename,
) for text, conf in zip(texts, confidences):
result["attention_gif"] = gif_filename result["confidences"][level].append(
{"text": text, "confidence": conf}
)
# Save gif with attention map
if attention_map:
attentions = prediction["attentions"][idx]
gif_filename = f"{output}/{image_path.stem}_{attention_map_level}.gif"
logger.info(f"Creating attention GIF in {gif_filename}")
# this returns polygons but unused for now.
plot_attention(
image=input_tensor[idx],
text=predicted_text,
weights=attentions,
level=attention_map_level,
scale=attention_map_scale,
word_separators=word_separators,
line_separators=line_separators,
display_polygons=predict_objects,
threshold_method=threshold_method,
threshold_value=threshold_value,
outname=gif_filename,
)
result["attention_gif"] = gif_filename
json_filename = f"{output}/{image_path.stem}.json" json_filename = f"{output}/{image_path.stem}.json"
logger.info(f"Saving JSON prediction in {json_filename}") logger.info(f"Saving JSON prediction in {json_filename}")
save_json(Path(json_filename), result) save_json(Path(json_filename), result)
def run( def run(
...@@ -376,6 +389,7 @@ def run( ...@@ -376,6 +389,7 @@ def run(
threshold_value, threshold_value,
image_extension, image_extension,
gpu_device, gpu_device,
batch_size,
): ):
""" """
Predict a single image save the output Predict a single image save the output
...@@ -395,6 +409,7 @@ def run( ...@@ -395,6 +409,7 @@ def run(
:param threshold_method: Thresholding method. Should be in ["otsu", "simple"]. :param threshold_method: Thresholding method. Should be in ["otsu", "simple"].
:param threshold_value: Thresholding value to use for the "simple" thresholding method. :param threshold_value: Thresholding value to use for the "simple" thresholding method.
:param gpu_device: Use a specific GPU if available. :param gpu_device: Use a specific GPU if available.
:param batch_size: Size of the batches for prediction.
""" """
# Create output directory if necessary # Create output directory if necessary
if not os.path.exists(output): if not os.path.exists(output):
...@@ -407,9 +422,9 @@ def run( ...@@ -407,9 +422,9 @@ def run(
dan_model.load(model, parameters, charset, mode="eval") dan_model.load(model, parameters, charset, mode="eval")
images = image_dir.rglob(f"*{image_extension}") if not image else [image] images = image_dir.rglob(f"*{image_extension}") if not image else [image]
for image_name in images: for image_batch in list_to_batches(images, n=batch_size):
process_image( process_batch(
image_name, image_batch,
dan_model, dan_model,
device, device,
output, output,
......
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from itertools import islice
import torch import torch
import torchvision.io as torchvision import torchvision.io as torchvision
...@@ -72,3 +74,13 @@ def ind_to_token(labels, ind, oov_symbol=None): ...@@ -72,3 +74,13 @@ def ind_to_token(labels, ind, oov_symbol=None):
else: else:
res = [labels[i] for i in ind] res = [labels[i] for i in ind]
return "".join(res) return "".join(res)
def list_to_batches(iterable, n):
"Batch data into tuples of length n. The last batch may be shorter."
# list_to_batches('ABCDEFG', 3) --> ABC DEF G
if n < 1:
raise ValueError("n must be at least one")
it = iter(iterable)
while batch := tuple(islice(it, n)):
yield batch
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import json import json
import shutil
import pytest import pytest
...@@ -270,9 +271,172 @@ def test_run_prediction( ...@@ -270,9 +271,172 @@ def test_run_prediction(
threshold_value=0, threshold_value=0,
image_extension=None, image_extension=None,
gpu_device=None, gpu_device=None,
batch_size=1,
) )
with (tmp_path / image_name).with_suffix(".json").open("r") as json_file: prediction = json.loads((tmp_path / image_name).with_suffix(".json").read_text())
prediction = json.load(json_file)
assert prediction == expected_prediction assert prediction == expected_prediction
@pytest.mark.parametrize(
"image_names, confidence_score, temperature, expected_predictions",
(
(
["0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84"],
None,
1.0,
[{"text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241"}],
),
(
["0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84"],
["word"],
1.0,
[
{
"text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",
"confidences": {
"by ner token": [],
"total": 1.0,
"word": [
{"text": "ⓈBellisson", "confidence": 1.0},
{"text": "ⒻGeorges", "confidence": 1.0},
{"text": "Ⓑ91", "confidence": 1.0},
{"text": "ⓁP", "confidence": 1.0},
{"text": "ⒸM", "confidence": 1.0},
{"text": "ⓀCh", "confidence": 1.0},
{"text": "ⓄPlombier", "confidence": 1.0},
{"text": "ⓅPatron?12241", "confidence": 1.0},
],
},
}
],
),
(
[
"0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84",
"0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84",
],
["word"],
1.0,
[
{
"text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",
"confidences": {
"by ner token": [],
"total": 1.0,
"word": [
{"text": "ⓈBellisson", "confidence": 1.0},
{"text": "ⒻGeorges", "confidence": 1.0},
{"text": "Ⓑ91", "confidence": 1.0},
{"text": "ⓁP", "confidence": 1.0},
{"text": "ⒸM", "confidence": 1.0},
{"text": "ⓀCh", "confidence": 1.0},
{"text": "ⓄPlombier", "confidence": 1.0},
{"text": "ⓅPatron?12241", "confidence": 1.0},
],
},
},
{
"text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",
"confidences": {
"by ner token": [],
"total": 1.0,
"word": [
{"text": "ⓈBellisson", "confidence": 1.0},
{"text": "ⒻGeorges", "confidence": 1.0},
{"text": "Ⓑ91", "confidence": 1.0},
{"text": "ⓁP", "confidence": 1.0},
{"text": "ⒸM", "confidence": 1.0},
{"text": "ⓀCh", "confidence": 1.0},
{"text": "ⓄPlombier", "confidence": 1.0},
{"text": "ⓅPatron?12241", "confidence": 1.0},
],
},
},
],
),
(
["0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84"],
["word"],
1.0,
[
{
"text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",
"confidences": {
"by ner token": [],
"total": 1.0,
"word": [
{"text": "ⓈBellisson", "confidence": 1.0},
{"text": "ⒻGeorges", "confidence": 1.0},
{"text": "Ⓑ91", "confidence": 1.0},
{"text": "ⓁP", "confidence": 1.0},
{"text": "ⒸM", "confidence": 1.0},
{"text": "ⓀCh", "confidence": 1.0},
{"text": "ⓄPlombier", "confidence": 1.0},
{"text": "ⓅPatron?12241", "confidence": 1.0},
],
},
}
],
),
(
[
"2c242f5c-e979-43c4-b6f2-a6d4815b651d",
"ffdec445-7f14-4f5f-be44-68d0844d0df1",
],
False,
1.0,
[
{"text": "Ⓢd ⒻCharles Ⓑ11 ⓁP ⒸC ⓀF Ⓞd Ⓟ14 31"},
{"text": "ⓈNaudin ⒻMarie Ⓑ53 ⓁS Ⓒv ⓀBelle mère"},
],
),
),
)
@pytest.mark.parametrize("batch_size", [1, 2])
def test_run_prediction_batch(
image_names,
confidence_score,
temperature,
expected_predictions,
prediction_data_path,
batch_size,
tmp_path,
):
# Make tmpdir and copy needed images inside
image_dir = tmp_path / "images"
image_dir.mkdir()
for image_name in image_names:
shutil.copyfile(
(prediction_data_path / "images" / image_name).with_suffix(".png"),
(image_dir / image_name).with_suffix(".png"),
)
run_prediction(
image=None,
image_dir=image_dir,
model=prediction_data_path / "popp_line_model.pt",
parameters=prediction_data_path / "parameters.yml",
charset=prediction_data_path / "charset.pkl",
output=tmp_path,
confidence_score=True if confidence_score else False,
confidence_score_levels=confidence_score if confidence_score else [],
attention_map=False,
attention_map_level=None,
attention_map_scale=0.5,
word_separators=[" ", "\n"],
line_separators=["\n"],
temperature=temperature,
predict_objects=False,
threshold_method="otsu",
threshold_value=0,
image_extension=".png",
gpu_device=None,
batch_size=batch_size,
)
for image_name, expected_prediction in zip(image_names, expected_predictions):
prediction = json.loads(
(tmp_path / image_name).with_suffix(".json").read_text()
)
assert prediction == expected_prediction
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment