diff --git a/dan/predict/__init__.py b/dan/predict/__init__.py
index 1b73e3b8ec20ca7ff87d42947e1469933683cdcb..2929f4db381bdaa27871084381d9b78943e07a21 100644
--- a/dan/predict/__init__.py
+++ b/dan/predict/__init__.py
@@ -140,4 +140,12 @@ def add_predict_parser(subcommands) -> None:
         type=int,
         required=False,
     )
+    parser.add_argument(
+        "--batch-size",
+        help="Size of prediction batches.",
+        type=int,
+        default=1,
+        required=False,
+    )
+
     parser.set_defaults(func=run)
diff --git a/dan/predict/attention.py b/dan/predict/attention.py
index 0b06425fefe17a69fa972c3273c127ceb35fee3d..dc8c93972923e079140f780dff739b40e8a91ae4 100644
--- a/dan/predict/attention.py
+++ b/dan/predict/attention.py
@@ -4,6 +4,7 @@ import re
 import cv2
 import numpy as np
 from PIL import Image
+from torchvision.transforms.functional import to_pil_image
 
 from dan import logger
 
@@ -179,7 +180,7 @@ def blend_coverage(coverage_vector, image, mask, scale):
     blend = Image.composite(image, coverage_vector, mask)
 
     # Resize to save time
-    blend = blend.resize((int(width * scale), int(height * scale)), Image.ANTIALIAS)
+    blend = blend.resize((int(width * scale), int(height * scale)), Image.LANCZOS)
 
     return blend
 
@@ -292,7 +293,7 @@ def plot_attention(
 ):
     """
     Create a gif by blending attention maps to the image for each text piece (char, word or line)
-    :param image: Input image in PIL format
+    :param image: Input image as torch.Tensor
     :param text: Text predicted by DAN
     :param weights: Attention weights of size (n_char, feature_height, feature_width)
     :param level: Level to display (must be in [char, word, line])
@@ -303,12 +304,11 @@ def plot_attention(
     :param display_polygons: Whether to plot extracted polygons
     """
 
-    height, width, _ = image.shape
+    image = to_pil_image(image)
     attention_map = []
 
-    # Convert to PIL Image and create mask
-    mask = Image.new("L", (width, height), color=(110))
-    image = Image.fromarray(image)
+    # Create the mask used for blending
+    mask = Image.new("L", (image.width, image.height), color=(110))
 
     # Split text into characters, words or lines
     text_list, offset = split_text(text, level, word_separators, line_separators)
@@ -320,7 +320,7 @@ def plot_attention(
     for text_piece in text_list:
         # Accumulate weights for the current word/line and resize to original image size
         coverage_vector = compute_coverage(
-            text_piece, max_value, tot_len, weights, (width, height)
+            text_piece, max_value, tot_len, weights, (image.width, image.height)
         )
 
         # Get polygons if flag is set:
@@ -333,7 +333,7 @@ def plot_attention(
                 weights,
                 threshold_method=threshold_method,
                 threshold_value=threshold_value,
-                size=(width, height),
+                size=(image.width, image.height),
             )
 
             if contour is not None:
diff --git a/dan/predict/prediction.py b/dan/predict/prediction.py
index ff87ca3eb720bf015b82c349f3d753caf1421863..3d144d82509555bd1ffb08625437ecf07c9f667c 100644
--- a/dan/predict/prediction.py
+++ b/dan/predict/prediction.py
@@ -20,7 +20,7 @@ from dan.predict.attention import (
     split_text_and_confidences,
 )
 from dan.transforms import get_normalization_transforms, get_preprocessing_transforms
-from dan.utils import ind_to_token, read_image
+from dan.utils import ind_to_token, list_to_batches, pad_images, read_image
 
 
 class DAN:
@@ -248,8 +248,8 @@ class DAN:
         return out
 
 
-def process_image(
-    image_path,
+def process_batch(
+    image_batch,
     dan_model,
     device,
     output,
@@ -264,20 +264,25 @@ def process_image(
     threshold_method,
     threshold_value,
 ):
-    # Load image and pre-process it
-    image = dan_model.preprocess(str(image_path))
-    logger.info("Image loaded.")
+    input_images, input_sizes = [], []
logger.info("Loading images...") + for image_path in image_batch: + # Load image and pre-process it + image = dan_model.preprocess(str(image_path)) + input_images.append(image) + input_sizes.append(image.shape[1:]) # Convert to tensor of size (batch_size, channel, height, width) with batch_size=1 - input_tensor = image.unsqueeze(0) - input_tensor = input_tensor.to(device) - input_sizes = [image.shape[1:]] + input_tensor = pad_images(input_images).to(device) + + logger.info("Images preprocessed!") # Parse delimiters to regex word_separators = parse_delimiters(word_separators) line_separators = parse_delimiters(line_separators) # Predict + logger.info("Predicting...") prediction = dan_model.predict( input_tensor, input_sizes, @@ -290,70 +295,78 @@ def process_image( threshold_method=threshold_method, threshold_value=threshold_value, ) + logger.info("Prediction parsing...") + + for idx, image_path in enumerate(image_batch): + predicted_text = prediction["text"][idx] + result = {"text": predicted_text} + + # Return extracted objects (coordinates, text, confidence) + if predict_objects: + result["objects"] = prediction["objects"][idx] + + # Return mean confidence score + if confidence_score: + result["confidences"] = {} + char_confidences = prediction["confidences"][idx] + # retrieve the index of the token ner + index = [ + pos + for pos, char in enumerate(predicted_text) + if char in ["â“", "â“Ÿ", "â““", "â“¡"] + ] - result = {} - result["text"] = prediction["text"][0] - - # Return extracted objects (coordinates, text, confidence) - if predict_objects: - result["objects"] = prediction["objects"][0] - - # Return mean confidence score - if confidence_score: - result["confidences"] = {} - char_confidences = prediction["confidences"][0] - text = result["text"] - # retrieve the index of the token ner - index = [pos for pos, char in enumerate(text) if char in ["â“", "â“Ÿ", "â““", "â“¡"]] - - # calculates scores by token - - result["confidences"]["by ner token"] = [ - { - "text": f"{text[current: next_token]}".replace("\n", " "), - "confidence_ner": f"{np.around(np.mean(char_confidences[current : next_token]), 2)}", - } - # We go up to -1 so that the last token matches until the end of the text - for current, next_token in pairwise(index + [-1]) - ] - result["confidences"]["total"] = np.around(np.mean(char_confidences), 2) - - for level in confidence_score_levels: - result["confidences"][level] = [] - texts, confidences, _ = split_text_and_confidences( - prediction["text"][0], - char_confidences, - level, - word_separators, - line_separators, - ) + # calculates scores by token - for text, conf in zip(texts, confidences): - result["confidences"][level].append({"text": text, "confidence": conf}) - - # Save gif with attention map - if attention_map: - gif_filename = f"{output}/{image_path.stem}_{attention_map_level}.gif" - logger.info(f"Creating attention GIF in {gif_filename}") - # this returns polygons but unused for now. 
-        plot_attention(
-            image=image,
-            text=prediction["text"][0],
-            weights=prediction["attentions"][0],
-            level=attention_map_level,
-            scale=attention_map_scale,
-            word_separators=word_separators,
-            line_separators=line_separators,
-            display_polygons=predict_objects,
-            threshold_method=threshold_method,
-            threshold_value=threshold_value,
-            outname=gif_filename,
-        )
-        result["attention_gif"] = gif_filename
+            result["confidences"]["by ner token"] = [
+                {
+                    "text": f"{predicted_text[current: next_token]}".replace("\n", " "),
+                    "confidence_ner": f"{np.around(np.mean(char_confidences[current : next_token]), 2)}",
+                }
+                # We go up to -1 so that the last token matches until the end of the text
+                for current, next_token in pairwise(index + [-1])
+            ]
+            result["confidences"]["total"] = np.around(np.mean(char_confidences), 2)
+
+            for level in confidence_score_levels:
+                result["confidences"][level] = []
+                texts, confidences, _ = split_text_and_confidences(
+                    predicted_text,
+                    char_confidences,
+                    level,
+                    word_separators,
+                    line_separators,
+                )
+
+                for text, conf in zip(texts, confidences):
+                    result["confidences"][level].append(
+                        {"text": text, "confidence": conf}
+                    )
+
+        # Save gif with attention map
+        if attention_map:
+            attentions = prediction["attentions"][idx]
+            gif_filename = f"{output}/{image_path.stem}_{attention_map_level}.gif"
+            logger.info(f"Creating attention GIF in {gif_filename}")
+            # This returns polygons, but they are unused for now.
+            plot_attention(
+                image=input_tensor[idx],
+                text=predicted_text,
+                weights=attentions,
+                level=attention_map_level,
+                scale=attention_map_scale,
+                word_separators=word_separators,
+                line_separators=line_separators,
+                display_polygons=predict_objects,
+                threshold_method=threshold_method,
+                threshold_value=threshold_value,
+                outname=gif_filename,
+            )
+            result["attention_gif"] = gif_filename
 
-    json_filename = f"{output}/{image_path.stem}.json"
-    logger.info(f"Saving JSON prediction in {json_filename}")
-    save_json(Path(json_filename), result)
+        json_filename = f"{output}/{image_path.stem}.json"
+        logger.info(f"Saving JSON prediction in {json_filename}")
+        save_json(Path(json_filename), result)
 
 
 def run(
@@ -376,6 +389,7 @@
     threshold_value,
     image_extension,
     gpu_device,
+    batch_size,
 ):
     """
-    Predict a single image save the output
+    Predict a single image and save the output
@@ -395,6 +409,7 @@
     :param threshold_method: Thresholding method. Should be in ["otsu", "simple"].
     :param threshold_value: Thresholding value to use for the "simple" thresholding method.
     :param gpu_device: Use a specific GPU if available.
+    :param batch_size: Size of the batches for prediction.
     """
     # Create output directory if necessary
     if not os.path.exists(output):
@@ -407,9 +422,9 @@
     dan_model.load(model, parameters, charset, mode="eval")
 
     images = image_dir.rglob(f"*{image_extension}") if not image else [image]
-    for image_name in images:
-        process_image(
-            image_name,
+    for image_batch in list_to_batches(images, n=batch_size):
+        process_batch(
+            image_batch,
             dan_model,
             device,
             output,
diff --git a/dan/utils.py b/dan/utils.py
index 1ec9a8cfcf4207f16894196fe0e9b6c2b22ea82f..9eddb9c561ceac19baab53cf78c33cf768e1e917 100644
--- a/dan/utils.py
+++ b/dan/utils.py
@@ -1,4 +1,6 @@
 # -*- coding: utf-8 -*-
+from itertools import islice
+
 import torch
 import torchvision.io as torchvision
 
@@ -72,3 +74,13 @@ def ind_to_token(labels, ind, oov_symbol=None):
     else:
         res = [labels[i] for i in ind]
     return "".join(res)
+
+
+def list_to_batches(iterable, n):
+    "Batch data into tuples of length n. The last batch may be shorter."
+    # list_to_batches('ABCDEFG', 3) --> ABC DEF G
+    if n < 1:
+        raise ValueError("n must be at least one")
+    it = iter(iterable)
+    while batch := tuple(islice(it, n)):
+        yield batch
diff --git a/tests/test_prediction.py b/tests/test_prediction.py
index db87cdb451d6ef7ea80ceb1afb8b5a698836d552..9d7409c4e185c830b805fadaee0866c13ea4debf 100644
--- a/tests/test_prediction.py
+++ b/tests/test_prediction.py
@@ -1,6 +1,7 @@
 # -*- coding: utf-8 -*-
 
 import json
+import shutil
 
 import pytest
 
@@ -270,9 +271,172 @@ def test_run_prediction(
         threshold_value=0,
         image_extension=None,
         gpu_device=None,
+        batch_size=1,
     )
 
-    with (tmp_path / image_name).with_suffix(".json").open("r") as json_file:
-        prediction = json.load(json_file)
-
+    prediction = json.loads((tmp_path / image_name).with_suffix(".json").read_text())
     assert prediction == expected_prediction
+
+
+@pytest.mark.parametrize(
+    "image_names, confidence_score, temperature, expected_predictions",
+    (
+        (
+            ["0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84"],
+            None,
+            1.0,
+            [{"text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241"}],
+        ),
+        (
+            ["0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84"],
+            ["word"],
+            1.0,
+            [
+                {
+                    "text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",
+                    "confidences": {
+                        "by ner token": [],
+                        "total": 1.0,
+                        "word": [
+                            {"text": "ⓈBellisson", "confidence": 1.0},
+                            {"text": "ⒻGeorges", "confidence": 1.0},
+                            {"text": "Ⓑ91", "confidence": 1.0},
+                            {"text": "ⓁP", "confidence": 1.0},
+                            {"text": "ⒸM", "confidence": 1.0},
+                            {"text": "ⓀCh", "confidence": 1.0},
+                            {"text": "ⓄPlombier", "confidence": 1.0},
+                            {"text": "ⓅPatron?12241", "confidence": 1.0},
+                        ],
+                    },
+                }
+            ],
+        ),
+        (
+            [
+                "0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84",
+                "0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84",
+            ],
+            ["word"],
+            1.0,
+            [
+                {
+                    "text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",
+                    "confidences": {
+                        "by ner token": [],
+                        "total": 1.0,
+                        "word": [
+                            {"text": "ⓈBellisson", "confidence": 1.0},
+                            {"text": "ⒻGeorges", "confidence": 1.0},
+                            {"text": "Ⓑ91", "confidence": 1.0},
+                            {"text": "ⓁP", "confidence": 1.0},
+                            {"text": "ⒸM", "confidence": 1.0},
+                            {"text": "ⓀCh", "confidence": 1.0},
+                            {"text": "ⓄPlombier", "confidence": 1.0},
+                            {"text": "ⓅPatron?12241", "confidence": 1.0},
+                        ],
+                    },
+                },
+                {
+                    "text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",
+                    "confidences": {
+                        "by ner token": [],
+                        "total": 1.0,
+                        "word": [
+                            {"text": "ⓈBellisson", "confidence": 1.0},
+                            {"text": "ⒻGeorges", "confidence": 1.0},
+                            {"text": "Ⓑ91", "confidence": 1.0},
+                            {"text": "ⓁP", "confidence": 1.0},
+                            {"text": "ⒸM", "confidence": 1.0},
+                            {"text": "ⓀCh", "confidence": 1.0},
+                            {"text": "ⓄPlombier", "confidence": 1.0},
+                            {"text": "ⓅPatron?12241", "confidence": 1.0},
+                        ],
+                    },
+                },
+            ],
+        ),
+        (
+            ["0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84"],
+            ["word"],
+            1.0,
+            [
+                {
+                    "text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",
+                    "confidences": {
+                        "by ner token": [],
+                        "total": 1.0,
+                        "word": [
+                            {"text": "ⓈBellisson", "confidence": 1.0},
+                            {"text": "ⒻGeorges", "confidence": 1.0},
+                            {"text": "Ⓑ91", "confidence": 1.0},
+                            {"text": "ⓁP", "confidence": 1.0},
+                            {"text": "ⒸM", "confidence": 1.0},
+                            {"text": "ⓀCh", "confidence": 1.0},
+                            {"text": "ⓄPlombier", "confidence": 1.0},
+                            {"text": "ⓅPatron?12241", "confidence": 1.0},
+                        ],
+                    },
+                }
+            ],
+        ),
+        (
+            [
+                "2c242f5c-e979-43c4-b6f2-a6d4815b651d",
+                "ffdec445-7f14-4f5f-be44-68d0844d0df1",
+            ],
+            False,
+            1.0,
+            [
+                {"text": "Ⓢd ⒻCharles Ⓑ11 ⓁP ⒸC ⓀF Ⓞd Ⓟ14 31"},
+                {"text": "ⓈNaudin ⒻMarie Ⓑ53 ⓁS Ⓒv ⓀBelle mère"},
+            ],
+        ),
+    ),
+)
+@pytest.mark.parametrize("batch_size", [1, 2])
+def test_run_prediction_batch(
+    image_names,
+    confidence_score,
+    temperature,
+    expected_predictions,
+    prediction_data_path,
+    batch_size,
+    tmp_path,
+):
+    # Make tmpdir and copy needed images inside
+    image_dir = tmp_path / "images"
+    image_dir.mkdir()
+    for image_name in image_names:
+        shutil.copyfile(
+            (prediction_data_path / "images" / image_name).with_suffix(".png"),
+            (image_dir / image_name).with_suffix(".png"),
+        )
+
+    run_prediction(
+        image=None,
+        image_dir=image_dir,
+        model=prediction_data_path / "popp_line_model.pt",
+        parameters=prediction_data_path / "parameters.yml",
+        charset=prediction_data_path / "charset.pkl",
+        output=tmp_path,
+        confidence_score=bool(confidence_score),
+        confidence_score_levels=confidence_score if confidence_score else [],
+        attention_map=False,
+        attention_map_level=None,
+        attention_map_scale=0.5,
+        word_separators=[" ", "\n"],
+        line_separators=["\n"],
+        temperature=temperature,
+        predict_objects=False,
+        threshold_method="otsu",
+        threshold_value=0,
+        image_extension=".png",
+        gpu_device=None,
+        batch_size=batch_size,
+    )
+
+    for image_name, expected_prediction in zip(image_names, expected_predictions):
+        prediction = json.loads(
+            (tmp_path / image_name).with_suffix(".json").read_text()
+        )
+        assert prediction == expected_prediction
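
Note on pad_images: process_batch builds its input tensor with the pad_images helper imported from dan.utils, but that helper's definition is not part of this diff. Below is a minimal sketch of what such a helper might look like, assuming a list of C x H x W tensors with a shared channel count, aligned to the top-left corner and zero-padded to the bottom-right; the padding value and alignment are assumptions, not necessarily the repository's actual implementation.

import torch


def pad_images(images):
    """Stack a list of C x H x W tensors into one B x C x H_max x W_max tensor.

    Hypothetical sketch: each image is copied into the top-left corner of its
    canvas; the remainder of the canvas is left as zero padding.
    """
    channels = images[0].shape[0]
    max_height = max(image.shape[1] for image in images)
    max_width = max(image.shape[2] for image in images)
    padded = torch.zeros(
        (len(images), channels, max_height, max_width), dtype=images[0].dtype
    )
    for index, image in enumerate(images):
        padded[index, :, : image.shape[1], : image.shape[2]] = image
    return padded

Padding to a common canvas is what lets images of different resolutions share a single (batch_size, channel, height, width) tensor; the per-image input_sizes list recorded in process_batch preserves the original dimensions so predictions and attention maps can still be mapped back to each unpadded image.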