Verified commit 14a829a4 authored by Mélodie Boillet

Apply 774887b1

parent a6d52a94
@@ -140,4 +140,12 @@ def add_predict_parser(subcommands) -> None:
type=int,
required=False,
)
parser.add_argument(
"--batch-size",
help="Size of prediction batches.",
type=int,
default=1,
required=False,
)
parser.set_defaults(func=run)
@@ -4,6 +4,7 @@ import re
import cv2
import numpy as np
from PIL import Image
from torchvision.transforms.functional import to_pil_image
from dan import logger
@@ -179,7 +180,7 @@ def blend_coverage(coverage_vector, image, mask, scale):
blend = Image.composite(image, coverage_vector, mask)
# Resize to save time
blend = blend.resize((int(width * scale), int(height * scale)), Image.ANTIALIAS)
blend = blend.resize((int(width * scale), int(height * scale)), Image.LANCZOS)
return blend
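For context (not part of this commit's changes beyond the line above): Pillow deprecated the `Image.ANTIALIAS` alias in 9.1.0 and removed it in 10.0.0; `Image.LANCZOS` names the same resampling filter, so the replacement is behaviour-preserving. A minimal sketch, assuming Pillow >= 9.1:

    from PIL import Image

    image = Image.new("RGB", (200, 100))
    # LANCZOS is the filter formerly aliased as ANTIALIAS; this downscale
    # matches the behaviour of the removed constant.
    blend = image.resize((int(200 * 0.5), int(100 * 0.5)), Image.LANCZOS)
    assert blend.size == (100, 50)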
@@ -292,7 +293,7 @@ def plot_attention(
):
"""
Create a GIF by blending attention maps onto the image for each text piece (char, word or line)
:param image: Input image in PIL format
:param image: Input image as torch.Tensor
:param text: Text predicted by DAN
:param weights: Attention weights of size (n_char, feature_height, feature_width)
:param level: Level to display (must be in [char, word, line])
@@ -302,13 +303,11 @@ def plot_attention(
:param line_separators: List of line separators
:param display_polygons: Whether to plot extracted polygons
"""
height, width, _ = image.shape
image = to_pil_image(image)
attention_map = []
# Convert to PIL Image and create mask
mask = Image.new("L", (width, height), color=(110))
image = Image.fromarray(image)
mask = Image.new("L", (image.width, image.height), color=(110))
# Split text into characters, words or lines
text_list, offset = split_text(text, level, word_separators, line_separators)
@@ -320,7 +319,7 @@
for text_piece in text_list:
# Accumulate weights for the current word/line and resize to original image size
coverage_vector = compute_coverage(
text_piece, max_value, tot_len, weights, (width, height)
text_piece, max_value, tot_len, weights, (image.width, image.height)
)
# Get polygons if flag is set:
@@ -333,7 +332,7 @@
weights,
threshold_method=threshold_method,
threshold_value=threshold_value,
size=(width, height),
size=(image.width, image.height),
)
if contour is not None:
......
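With the signature change above, `plot_attention` now receives the torch.Tensor that was fed to the model and converts it to a PIL image itself via torchvision's `to_pil_image`. A minimal sketch of that conversion, assuming a float (C, H, W) tensor with values in [0, 1]:

    import torch
    from torchvision.transforms.functional import to_pil_image

    tensor = torch.rand(3, 128, 256)   # (channels, height, width)
    pil_image = to_pil_image(tensor)   # PIL.Image; note size is (width, height)
    assert (pil_image.width, pil_image.height) == (256, 128)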
@@ -20,7 +20,7 @@ from dan.predict.attention import (
split_text_and_confidences,
)
from dan.transforms import get_preprocessing_transforms
from dan.utils import ind_to_token, read_image
from dan.utils import ind_to_token, list_to_batches, pad_images, read_image
class DAN:
@@ -255,8 +255,8 @@ class DAN:
return out
def process_image(
image_path,
def process_batch(
image_batch,
dan_model,
device,
output,
@@ -271,20 +271,24 @@ def process_image(
threshold_method,
threshold_value,
):
# Load image and pre-process it
image = dan_model.preprocess(str(image_path))
logger.info("Image loaded.")
input_images, input_sizes = [], []
logger.info("Loading images...")
for image_path in image_batch:
# Load image and pre-process it
image = dan_model.preprocess(str(image_path))
input_images.append(image)
input_sizes.append(image.shape[1:])
# Convert to tensor of size (batch_size, channel, height, width) with batch_size=1
input_tensor = image.unsqueeze(0)
input_tensor = input_tensor.to(device)
input_sizes = [image.shape[1:]]
input_tensor = pad_images(input_images).to(device)
logger.info("Images preprocessed!")
# Parse delimiters to regex
word_separators = parse_delimiters(word_separators)
line_separators = parse_delimiters(line_separators)
# Predict
logger.info("Predicting...")
prediction = dan_model.predict(
input_tensor,
input_sizes,
@@ -298,69 +302,76 @@
threshold_value=threshold_value,
)
result = {}
result["text"] = prediction["text"][0]
# Return extracted objects (coordinates, text, confidence)
if predict_objects:
result["objects"] = prediction["objects"][0]
# Return mean confidence score
if confidence_score:
result["confidences"] = {}
char_confidences = prediction["confidences"][0]
text = result["text"]
# Retrieve the indices of the NER tokens
index = [pos for pos, char in enumerate(text) if char in ["", "", "", ""]]
# Compute confidence scores per NER token
result["confidences"]["by ner token"] = [
{
"text": f"{text[current: next_token]}".replace("\n", " "),
"confidence_ner": f"{np.around(np.mean(char_confidences[current : next_token]), 2)}",
}
# We go up to -1 so that the last token matches until the end of the text
for current, next_token in pairwise(index + [-1])
]
result["confidences"]["total"] = np.around(np.mean(char_confidences), 2)
for level in confidence_score_levels:
result["confidences"][level] = []
texts, confidences, _ = split_text_and_confidences(
prediction["text"][0],
char_confidences,
level,
word_separators,
line_separators,
)
logger.info("Prediction parsing...")
for idx, image_path in enumerate(image_batch):
predicted_text = prediction["text"][idx]
result = {"text": predicted_text}
# Return extracted objects (coordinates, text, confidence)
if predict_objects:
result["objects"] = prediction["objects"][idx]
# Return mean confidence score
if confidence_score:
result["confidences"] = {}
char_confidences = prediction["confidences"][idx]
# Retrieve the indices of the NER tokens
index = [
pos
for pos, char in enumerate(predicted_text)
if char in ["", "", "", ""]
]
for text, conf in zip(texts, confidences):
result["confidences"][level].append({"text": text, "confidence": conf})
# Save gif with attention map
if attention_map:
gif_filename = f"{output}/{image_path.stem}_{attention_map_level}.gif"
logger.info(f"Creating attention GIF in {gif_filename}")
# plot_attention also returns polygons; they are unused for now.
plot_attention(
image=image,
text=prediction["text"][0],
weights=prediction["attentions"][0],
level=attention_map_level,
scale=attention_map_scale,
word_separators=word_separators,
line_separators=line_separators,
display_polygons=predict_objects,
threshold_method=threshold_method,
threshold_value=threshold_value,
outname=gif_filename,
)
result["attention_gif"] = gif_filename
# Compute confidence scores per NER token
result["confidences"]["by ner token"] = [
{
"text": f"{predicted_text[current: next_token]}".replace("\n", " "),
"confidence_ner": f"{np.around(np.mean(char_confidences[current : next_token]), 2)}",
}
# We go up to -1 so that the last token matches until the end of the text
for current, next_token in pairwise(index + [-1])
]
result["confidences"]["total"] = np.around(np.mean(char_confidences), 2)
for level in confidence_score_levels:
result["confidences"][level] = []
texts, confidences, _ = split_text_and_confidences(
predicted_text,
char_confidences,
level,
word_separators,
line_separators,
)
for text, conf in zip(texts, confidences):
result["confidences"][level].append(
{"text": text, "confidence": conf}
)
# Save gif with attention map
if attention_map:
attentions = prediction["attentions"][idx]
gif_filename = f"{output}/{image_path.stem}_{attention_map_level}.gif"
logger.info(f"Creating attention GIF in {gif_filename}")
# plot_attention also returns polygons; they are unused for now.
plot_attention(
image=input_tensor[idx],
text=predicted_text,
weights=attentions,
level=attention_map_level,
scale=attention_map_scale,
word_separators=word_separators,
line_separators=line_separators,
display_polygons=predict_objects,
threshold_method=threshold_method,
threshold_value=threshold_value,
outname=gif_filename,
)
result["attention_gif"] = gif_filename
json_filename = f"{output}/{image_path.stem}.json"
logger.info(f"Saving JSON prediction in {json_filename}")
save_json(Path(json_filename), result)
json_filename = f"{output}/{image_path.stem}.json"
logger.info(f"Saving JSON prediction in {json_filename}")
save_json(Path(json_filename), result)
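The batching above relies on `pad_images`, imported from `dan.utils` but not shown in this commit: the preprocessed images have different sizes, so they must be padded to a common shape before being stacked into one (batch, channel, height, width) tensor, while `input_sizes` keeps the original shapes for the model. A plausible sketch under that assumption (the real helper may pad or be structured differently):

    import torch

    def pad_images_sketch(images):
        """Stack variable-sized (C, H, W) tensors into one
        (batch, C, max_height, max_width) tensor, zero-padding
        each image on the bottom and right."""
        max_height = max(image.shape[1] for image in images)
        max_width = max(image.shape[2] for image in images)
        batch = torch.zeros(len(images), images[0].shape[0], max_height, max_width)
        for i, image in enumerate(images):
            batch[i, :, : image.shape[1], : image.shape[2]] = image
        return batch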
def run(
@@ -383,6 +394,7 @@ def run(
threshold_value,
image_extension,
gpu_device,
batch_size,
):
"""
Predict a single image or all images in a directory, and save the outputs
@@ -402,6 +414,7 @@ def run(
:param threshold_method: Thresholding method. Should be in ["otsu", "simple"].
:param threshold_value: Thresholding value to use for the "simple" thresholding method.
:param gpu_device: Use a specific GPU if available.
:param batch_size: Size of the batches for prediction.
"""
# Create output directory if necessary
if not os.path.exists(output):
@@ -414,9 +427,9 @@
dan_model.load(model, parameters, charset, mode="eval")
images = image_dir.rglob(f"*{image_extension}") if not image else [image]
for image_name in images:
process_image(
image_name,
for image_batch in list_to_batches(images, n=batch_size):
process_batch(
image_batch,
dan_model,
device,
output,
......
# -*- coding: utf-8 -*-
from itertools import islice
import torch
import torchvision.io as torchvision
@@ -68,3 +70,13 @@ def ind_to_token(labels, ind, oov_symbol=None):
else:
res = [labels[i] for i in ind]
return "".join(res)
def list_to_batches(iterable, n):
"Batch data into tuples of length n. The last batch may be shorter."
# list_to_batches('ABCDEFG', 3) --> ABC DEF G
if n < 1:
raise ValueError("n must be at least one")
it = iter(iterable)
while batch := tuple(islice(it, n)):
yield batch
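`list_to_batches` lazily groups any iterable into fixed-size tuples, with a shorter final batch; the semantics match `itertools.batched`, which joined the standard library in Python 3.12. Expected behaviour:

    >>> list(list_to_batches("ABCDEFG", n=3))
    [('A', 'B', 'C'), ('D', 'E', 'F'), ('G',)]
    >>> list(list_to_batches(iter(range(5)), n=2))
    [(0, 1), (2, 3), (4,)]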
# -*- coding: utf-8 -*-
import json
import shutil
import pytest
@@ -270,9 +271,173 @@ def test_run_prediction(
threshold_value=0,
image_extension=None,
gpu_device=None,
batch_size=1,
)
with (tmp_path / image_name).with_suffix(".json").open("r") as json_file:
prediction = json.load(json_file)
prediction = json.loads((tmp_path / image_name).with_suffix(".json").read_text())
assert prediction == expected_prediction
@pytest.mark.parametrize(
"image_names, confidence_score, temperature, expected_predictions",
(
(
["0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84"],
None,
1.0,
[{"text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241"}],
),
(
["0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84"],
["word"],
1.0,
[
{
"text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",
"confidences": {
"by ner token": [],
"total": 1.0,
"word": [
{"text": "ⓈBellisson", "confidence": 1.0},
{"text": "ⒻGeorges", "confidence": 1.0},
{"text": "Ⓑ91", "confidence": 1.0},
{"text": "ⓁP", "confidence": 1.0},
{"text": "ⒸM", "confidence": 1.0},
{"text": "ⓀCh", "confidence": 1.0},
{"text": "ⓄPlombier", "confidence": 1.0},
{"text": "ⓅPatron?12241", "confidence": 1.0},
],
},
}
],
),
(
[
"0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84",
"0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84",
],
["word"],
1.0,
[
{
"text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",
"confidences": {
"by ner token": [],
"total": 1.0,
"word": [
{"text": "ⓈBellisson", "confidence": 1.0},
{"text": "ⒻGeorges", "confidence": 1.0},
{"text": "Ⓑ91", "confidence": 1.0},
{"text": "ⓁP", "confidence": 1.0},
{"text": "ⒸM", "confidence": 1.0},
{"text": "ⓀCh", "confidence": 1.0},
{"text": "ⓄPlombier", "confidence": 1.0},
{"text": "ⓅPatron?12241", "confidence": 1.0},
],
},
},
{
"text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",
"confidences": {
"by ner token": [],
"total": 1.0,
"word": [
{"text": "ⓈBellisson", "confidence": 1.0},
{"text": "ⒻGeorges", "confidence": 1.0},
{"text": "Ⓑ91", "confidence": 1.0},
{"text": "ⓁP", "confidence": 1.0},
{"text": "ⒸM", "confidence": 1.0},
{"text": "ⓀCh", "confidence": 1.0},
{"text": "ⓄPlombier", "confidence": 1.0},
{"text": "ⓅPatron?12241", "confidence": 1.0},
],
},
},
],
),
(
["0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84"],
["word"],
1.0,
[
{
"text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",
"confidences": {
"by ner token": [],
"total": 1.0,
"word": [
{"text": "ⓈBellisson", "confidence": 1.0},
{"text": "ⒻGeorges", "confidence": 1.0},
{"text": "Ⓑ91", "confidence": 1.0},
{"text": "ⓁP", "confidence": 1.0},
{"text": "ⒸM", "confidence": 1.0},
{"text": "ⓀCh", "confidence": 1.0},
{"text": "ⓄPlombier", "confidence": 1.0},
{"text": "ⓅPatron?12241", "confidence": 1.0},
],
},
}
],
),
(
[
"2c242f5c-e979-43c4-b6f2-a6d4815b651d",
"ffdec445-7f14-4f5f-be44-68d0844d0df1",
],
False,
1.0,
[
{"text": "Ⓢd ⒻCharles Ⓑ11 ⓁP ⒸC ⓀF Ⓞd Ⓟ14 31"},
{"text": "ⓈNaudin ⒻMarie Ⓑ53 ⓁS Ⓒv ⓀBelle mère"},
],
),
),
)
@pytest.mark.parametrize("batch_size", [1, 2])
def test_run_prediction_batch(
image_names,
confidence_score,
temperature,
expected_predictions,
prediction_data_path,
batch_size,
tmp_path,
):
# Make tmpdir and copy needed images inside
image_dir = tmp_path / "images"
image_dir.mkdir()
for image_name in image_names:
shutil.copyfile(
(prediction_data_path / "images" / image_name).with_suffix(".png"),
(image_dir / image_name).with_suffix(".png"),
)
run_prediction(
image=None,
image_dir=image_dir,
model=prediction_data_path / "popp_line_model.pt",
parameters=prediction_data_path / "parameters.yml",
charset=prediction_data_path / "charset.pkl",
output=tmp_path,
confidence_score=True if confidence_score else False,
confidence_score_levels=confidence_score if confidence_score else [],
attention_map=False,
attention_map_level=None,
attention_map_scale=0.5,
word_separators=[" ", "\n"],
line_separators=["\n"],
temperature=temperature,
predict_objects=False,
threshold_method="otsu",
threshold_value=0,
image_extension=".png",
gpu_device=None,
batch_size=batch_size,
)
for image_name, expected_prediction in zip(image_names, expected_predictions):
prediction = json.loads(
(tmp_path / image_name).with_suffix(".json").read_text()
)
assert prediction == expected_prediction