diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000000000000000000000000000000000000..371419be9f4dd6c0cbd4afa21f8b9dc9a99da4ba --- /dev/null +++ b/.gitattributes @@ -0,0 +1 @@ +*gif filter=lfs diff=lfs merge=lfs -text diff --git a/README.md b/README.md index e0db1ef23775ca6b120700ca3dcabd6df7bdb047..3f50eb5258f7f61ae28cf975d31cde5bad98e599 100644 --- a/README.md +++ b/README.md @@ -56,3 +56,6 @@ See the [dedicated section](https://teklia.gitlab.io/atr/dan/usage/train/) on th ### Synthetic data generation See the [dedicated section](https://teklia.gitlab.io/atr/dan/usage/generate/) on the official DAN documentation. + +### Model prediction +See the [dedicated section](https://teklia.gitlab.io/atr/dan/usage/predict/) on the official DAN documentation. diff --git a/dan/cli.py b/dan/cli.py index ddf244f4190c0c7b4a12236df7223ab950d3e566..c68a119447d4eab41a92fa222e05e36b9e613794 100644 --- a/dan/cli.py +++ b/dan/cli.py @@ -5,6 +5,7 @@ import errno from dan.datasets import add_dataset_parser from dan.ocr import add_train_parser from dan.ocr.line import add_generate_parser +from dan.predict import add_predict_parser def get_parser(): @@ -14,6 +15,7 @@ def get_parser(): add_dataset_parser(subcommands) add_train_parser(subcommands) add_generate_parser(subcommands) + add_predict_parser(subcommands) return parser diff --git a/dan/predict/__init__.py b/dan/predict/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..92aa0f15da0ff46bc33902f6bcf4506da8d38b5d --- /dev/null +++ b/dan/predict/__init__.py @@ -0,0 +1,86 @@ +# -*- coding: utf-8 -*- +""" +Predict on an image using a trained DAN model. +""" + +import pathlib + +from dan.predict.prediction import run + + +def add_predict_parser(subcommands) -> None: + parser = subcommands.add_parser( + "predict", + description=__doc__, + help=__doc__, + ) + + # Required arguments. + parser.add_argument( + "--image", + type=pathlib.Path, + help="Path to the image to predict.", + required=True, + ) + parser.add_argument( + "--model", + type=pathlib.Path, + help="Path to the model to use for prediction.", + required=True, + ) + parser.add_argument( + "--parameters", + type=pathlib.Path, + help="Path to the YAML parameters file.", + required=True, + default="page", + ) + parser.add_argument( + "--charset", + type=pathlib.Path, + help="Path to the charset file.", + required=True, + ) + parser.add_argument( + "--output", + type=pathlib.Path, + help="Path to the output folder.", + required=True, + ) + # Optional arguments. + parser.add_argument( + "--scale", + type=float, + default=1.0, + required=False, + help="Image scaling factor before feeding it to DAN", + ) + parser.add_argument( + "--confidence-score", + action="store_true", + help="Whether to return confidence scores.", + required=False, + ) + parser.add_argument( + "--attention-map", + action="store_true", + help="Whether to plot attention maps.", + required=False, + ) + parser.add_argument( + "--attention-map-level", + type=str, + choices=["line", "word", "char"], + default="line", + help="Level of attention maps.", + required=False, + ) + parser.add_argument( + "--attention-map-scale", + type=float, + default=0.5, + help="Image scaling factor before creating the GIF", + required=False, + ) + + parser.set_defaults(func=run) diff --git a/dan/predict/attention.py b/dan/predict/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..bdfe57a1f736bc936be78179d8e52270a84911bf --- /dev/null +++ b/dan/predict/attention.py @@ -0,0 +1,89 @@ +# -*- coding: utf-8 -*- +import cv2 +import numpy as np +from PIL import Image + +from dan import logger + + +def split_text(text, level): + """ + Split text into a list of characters, word, or lines. + :param text: Text prediction from DAN + :param level: Level to visualize (char, word, line) + """ + # split into characters + if level == "char": + text_split = list(text) + offset = 0 + # split into words + elif level == "word": + text = text.replace("\n", " ") + text_split = text.split(" ") + offset = 1 + # split into lines + elif level == "line": + text_split = text.split("\n") + offset = 1 + else: + logger.error("Level should be either 'char', 'word', or 'line'") + return text_split, offset + + +def plot_attention(image, text, weights, level, scale, outname): + """ + Create a gif by blending attention maps to the image for each text piece (char, word or line) + :param image: Input image in PIL format + :param text: Text predicted by DAN + :param weights: Attention weights of size (n_char, feature_height, feature_width) + :param level: Level to display (must be in [char, word, line]) + :param scale: Scaling factor for the output gif image + :param outname: Name of the gif image + """ + height, width, _ = image.shape + attention_map = [] + + # Convert to PIL Image and create mask + mask = Image.new("L", (width, height), color=(110)) + image = Image.fromarray(image) + + # Split text into characters, words or lines + text_list, offset = split_text(text, level) + + # Iterate on characters, words or lines + tot_len = 0 + + max_value = weights.sum(0).max() + for text_piece in text_list: + # blank vector to accumulate weights for the current word/line + coverage_vector = np.zeros((height, width)) + for i in range(len(text_piece)): + local_weight = weights[i + tot_len] + local_weight = cv2.resize(local_weight, (width, height)) + coverage_vector = np.clip(coverage_vector + local_weight, 0, 1) + + # Keep track of text length + tot_len += len(text_piece) + offset + + # Normalize coverage vector + coverage_vector = (coverage_vector / max_value * 255).astype(np.uint8) + + # Blend coverage vector with original image + blank_array = np.zeros((height, width)).astype(np.uint8) + coverage_vector = Image.fromarray( + np.stack([coverage_vector, blank_array, blank_array], axis=2), "RGB" + ) + blend = Image.composite(image, coverage_vector, mask) + + # Resize to save time + blend = blend.resize((int(width * scale), int(height * scale)), Image.ANTIALIAS) + attention_map.append(blend) + + attention_map[0].save( + outname, + save_all=True, + format="GIF", + append_images=attention_map[1:], + duration=1000, + loop=True, + ) diff --git a/dan/predict.py b/dan/predict/prediction.py similarity index 69% rename from dan/predict.py rename to dan/predict/prediction.py index 2f855f72c2031d8ea8daea62ad28384b525f5755..0bb427eaf62893d3ff0589659059d31a4f913392 100644 --- a/dan/predict.py +++ b/dan/predict/prediction.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- -import logging +import os import pickle import cv2 @@ -8,9 +8,13 @@ import numpy as np import torch import yaml +from dan import logger +from dan.datasets.extract.utils import save_json from dan.decoder import GlobalHTADecoder from dan.models import FCN_Encoder from dan.ocr.utils import LM_ind_to_str +from dan.predict.attention import plot_attention +from dan.utils import read_image class DAN: @@ -50,7 +54,7 @@ class DAN: decoder = GlobalHTADecoder(parameters["decoder"]).to(self.device) decoder.load_state_dict(checkpoint["decoder_state_dict"], strict=True) - logging.debug(f"Loaded model {model_path}") + logger.debug(f"Loaded model {model_path}") if mode == "train": encoder.train() @@ -81,12 +85,13 @@ class DAN: input_image = (input_image - self.mean) / self.std return input_image - def predict(self, input_tensor, input_sizes, confidences=False): + def predict(self, input_tensor, input_sizes, confidences=False, attentions=False): """ Run prediction on an input image. :param input_tensor: A batch of images to predict. :param input_sizes: The original images sizes. :param confidences: Return the characters probabilities. + :param attentions: Return characters attention weights. """ input_tensor.to(self.device) @@ -105,6 +110,7 @@ class DAN: whole_output = list() confidence_scores = list() + attention_maps = list() cache = None hidden_predict = None @@ -137,6 +143,7 @@ class DAN: num_pred=1, ) whole_output.append(output) + attention_maps.append(weights) confidence_scores.append( torch.max(torch.softmax(pred[:, :], dim=1), dim=1).values ) @@ -161,6 +168,8 @@ class DAN: confidence_scores = ( torch.cat(confidence_scores, dim=1).cpu().detach().numpy() ) + attention_maps = torch.cat(attention_maps, dim=1).cpu().detach().numpy() + predicted_tokens = predicted_tokens[:, 1:] prediction_len[torch.eq(reached_end, False)] = self.max_chars - 1 predicted_tokens = [ @@ -172,8 +181,76 @@ class DAN: predicted_text = [ LM_ind_to_str(self.charset, t, oov_symbol="") for t in predicted_tokens ] - logging.info("Images processed") + logger.info("Images processed") + out = {"text": predicted_text} if confidences: - return predicted_text, confidence_scores - return predicted_text + out["confidences"] = confidence_scores + if attentions: + out["attentions"] = attention_maps + return out + + +def run( + image, + model, + parameters, + charset, + output, + scale, + confidence_score, + attention_map, + attention_map_level, + attention_map_scale, +): + # Create output directory if necessary + if not os.path.exists(output): + os.mkdir(output) + + # Load model + device = "cuda" if torch.cuda.is_available() else "cpu" + dan_model = DAN(device) + dan_model.load(model, parameters, charset, mode="eval") + + # Load image and pre-process it + im = read_image(image, scale=scale) + logger.info("Image loaded.") + im_p = dan_model.preprocess(im) + logger.debug("Image pre-processed.") + + # Convert to tensor of size (batch_size, channel, height, width) with batch_size=1 + input_tensor = torch.tensor(im_p).permute(2, 0, 1).unsqueeze(0) + input_tensor = input_tensor.to(device) + input_sizes = [im.shape[:2]] + + # Predict + prediction = dan_model.predict( + input_tensor, + input_sizes, + confidences=confidence_score, + attentions=attention_map, + ) + result = {"text": prediction["text"][0]} + + # Average character-based confidence scores + if confidence_score: + # TODO: select the level for confidence scores (char, word, line, total) + result["confidence"] = np.around(np.mean(prediction["confidences"][0]), 2) + + # Save gif with attention map + if attention_map: + gif_filename = f"{output}/{image.stem}_{attention_map_level}.gif" + logger.info(f"Creating attention GIF in {gif_filename}") + plot_attention( + image=im, + text=prediction["text"][0], + weights=prediction["attentions"][0], + level=attention_map_level, + scale=attention_map_scale, + outname=gif_filename, + ) + result["attention_gif"] = gif_filename + + json_filename = f"{output}/{image.stem}.json" + logger.info(f"Saving JSON prediction in {json_filename}") + save_json(json_filename, result) diff --git a/dan/utils.py b/dan/utils.py index 50f7311d602c97e80c308637e1b8cb37d8c90f95..2fd1529ac5a5240a2ff97cfd2957e70f7ef22ce1 100644 --- a/dan/utils.py +++ b/dan/utils.py @@ -185,3 +185,17 @@ def pad_image_width_random(img, new_width, padding_value, max_pad_left_ratio=1): pad_right = np.ones((h, pad_right, c), dtype=img.dtype) * padding_value img = np.concatenate([pad_left, img, pad_right], axis=1) return img + + +def read_image(filename, scale=1.0): + """ + Read image and rescale it + :param filename: Image path + :param scale: Scaling factor before prediction + """ + image = cv2.cvtColor(cv2.imread(str(filename)), cv2.COLOR_BGR2RGB) + if scale != 1.0: + width = int(image.shape[1] * scale) + height = int(image.shape[0] * scale) + image = cv2.resize(image, (width, height), interpolation=cv2.INTER_AREA) + return image diff --git a/docs/assets/example_line.gif b/docs/assets/example_line.gif new file mode 100644 index 0000000000000000000000000000000000000000..1ddd46c336733d44c7a5166ec0ee89a94afd8db2 --- /dev/null +++ b/docs/assets/example_line.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8c5d731cbaa7ebc131d7a536db23985f04fd7832808c64b9e62134cce7f99cfe +size 11928807 diff --git a/docs/assets/example_word.gif b/docs/assets/example_word.gif new file mode 100644 index 0000000000000000000000000000000000000000..d988894627d8388ebd4a30251dc1e9c188e2be13 --- /dev/null +++ b/docs/assets/example_word.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6f8935185bfb0cf0297b0a2ec501170c37e63d6880f1c14c2c0968e201b4195f +size 80353485 diff --git a/docs/ref/predict.md b/docs/ref/predict.md deleted file mode 100644 index f5295f31154b117b3b6a554ed82962d5a6d3bf1d..0000000000000000000000000000000000000000 --- a/docs/ref/predict.md +++ /dev/null @@ -1,3 +0,0 @@ -# Inference - -::: dan.predict diff --git a/docs/ref/predict/attention.md b/docs/ref/predict/attention.md new file mode 100644 index 0000000000000000000000000000000000000000..1b1eb6c69305824bd576a5dd450cb6d98fd2322c --- /dev/null +++ b/docs/ref/predict/attention.md @@ -0,0 +1,3 @@ +# Attention + +::: dan.predict.attention diff --git a/docs/ref/predict/prediction.md b/docs/ref/predict/prediction.md new file mode 100644 index 0000000000000000000000000000000000000000..c335c66034e77bb05e0c4be5c102b58e48c70f1a --- /dev/null +++ b/docs/ref/predict/prediction.md @@ -0,0 +1,3 @@ +# Inference + +::: dan.predict.prediction diff --git a/docs/usage/index.md b/docs/usage/index.md index b4a369ab2d71ef3adb5053a8529a950f88acada6..3c4d031155930485ebcdd5e6635e6d80df8dd9f8 100644 --- a/docs/usage/index.md +++ b/docs/usage/index.md @@ -10,3 +10,6 @@ When `teklia-dan` is installed in your environment, you may use the following co `teklia-dan generate` : To generate synthetic data to train DAN models. More details in [the dedicated section](./generate.md). + +`teklia-dan predict` +: To predict an image using a trained DAN model. More details in [the dedicated section](./predict.md). diff --git a/docs/usage/predict.md b/docs/usage/predict.md new file mode 100644 index 0000000000000000000000000000000000000000..4c5d905022b8a0b24de6dee3ae59b275d8094b72 --- /dev/null +++ b/docs/usage/predict.md @@ -0,0 +1,110 @@ +# Predict + +## Description + +Use the `teklia-dan predict` command to predict a trained DAN model on an image. + +| Parameter | Description | Type | Default | +| ------------------------------ | ---------------------------------------------------------------------------- | -------- | ------- | +| `--image` | Path to the image to predict. | `Path` | | +| `--model` | Path to the model to use for prediction | `Path` | | +| `--parameters` | Path to the YAML parameters file. | `Path` | | +| `--charset` | Path to the charset file. | `Path` | | +| `--output` | Path to the output folder. Results will be saved in this directory. | `Path` | | +| `--scale` | Image scaling factor before feeding it to DAN. | `float` | 1.0 | +| `--confidence-score` | Whether to return confidence scores. | `bool` | False | +| `--attention-map` | Whether to plot attention maps. | `bool` | False | +| `--attention-map-level` | Level to plot the attention maps. Should be in `["line", "word", "char"]`. | `str` | line | +| `--attention-map-scale` | Image scaling factor before creating the GIF. | `float` | 0.5 | + + +## Examples + +### Predict with confidence scores + +To run a prediction with confidence scores, run this command: +```shell +teklia-dan predict \ + --image dan_humu_page/example.jpg \ + --model dan_humu_page/model.pt \ + --parameters dan_humu_page/parameters.yml \ + --charset dan_humu_page/charset.pkl \ + --output dan_humu_page/predict/ \ + --scale 0.5 \ + --confidence-score +``` +It will create the following JSON file named `dan_humu_page/predict/example.json` + +```json +{ + "text": "Hansteensgt. 2 IV 28/4 - 19\nKj\u00e6re Gerhard.\nTak for Brevet om Boken og Haven\nog Crokus og Blaaveis og tak fordi\nDu vilde be mig derut sammen\nmed Kris og Ragna. Men vet Du\nda ikke, at Kris reiste med sin S\u00f8-\nster Fru Cr\u00f8ger til Lillehammer\nnogle Dage efter Begravelsen? Hen\ndes Address er Amtsingeni\u00f8r\nCr\u00f8ger. Hun skriver at de blir\nder til lidt ut i Mai. Nu er hun\nnoksaa medtat skj\u00f8nner jeg af Sorg\nog af L\u00e6ngsel, skriver saameget r\u00f8-\nrende om Oluf. Ragna har det\nherligt, skriver hun. Hun er bare\ngla, og det vet jeg, at \"Oluf er gla over,\nder hvor han nu er. Jeg har saa in-\nderlig ondt af hende, og om Du skrev\net Par Ord tror jeg det vilde gj\u00f8re\nhende godt. - Jeg gl\u00e6der mig over,\nat Du har skrevet en Bok, og\njeg er vis paa, at den er god.", + "confidence": 0.99 +} +``` + +### Predict with confidence scores and line-level attention maps + +To run a prediction with confidence scores and plot line-level attention maps, run this command: + +```shell +teklia-dan predict \ + --image dan_humu_page/example.jpg \ + --model dan_humu_page/model.pt \ + --parameters dan_humu_page/parameters.yml \ + --charset dan_humu_page/charset.pkl \ + --output dan_humu_page/predict/ \ + --scale 0.5 \ + --confidence-score \ + --attention-map \ +``` + +It will create the following JSON file named `dan_humu_page/predict/example.json` and a GIF showing a word-level attention map `dan_humu_page/predict/example_line.gif` + +```json +{ + "text": "Hansteensgt. 2 IV 28/4 - 19\nKj\u00e6re Gerhard.\nTak for Brevet om Boken og Haven\nog Crokus og Blaaveis og tak fordi\nDu vilde be mig derut sammen\nmed Kris og Ragna. Men vet Du\nda ikke, at Kris reiste med sin S\u00f8-\nster Fru Cr\u00f8ger til Lillehammer\nnogle Dage efter Begravelsen? Hen\ndes Address er Amtsingeni\u00f8r\nCr\u00f8ger. Hun skriver at de blir\nder til lidt ut i Mai. Nu er hun\nnoksaa medtat skj\u00f8nner jeg af Sorg\nog af L\u00e6ngsel, skriver saameget r\u00f8-\nrende om Oluf. Ragna har det\nherligt, skriver hun. Hun er bare\ngla, og det vet jeg, at \"Oluf er gla over,\nder hvor han nu er. Jeg har saa in-\nderlig ondt af hende, og om Du skrev\net Par Ord tror jeg det vilde gj\u00f8re\nhende godt. - Jeg gl\u00e6der mig over,\nat Du har skrevet en Bok, og\njeg er vis paa, at den er god.", + "confidence": 0.99, + "attention_gif": "dan_humu_page/predict/example_line.gif" +} +``` +<video autoplay> + <source src="../assets/example_line.gif"> +</video> + +### Predict with confidence scores and word-level attention maps + +To run a prediction with confidence scores and plot word-level attention maps, run this command: + +```shell +teklia-dan predict \ + --image dan_humu_page/example.jpg \ + --model dan_humu_page/model.pt \ + --parameters dan_humu_page/parameters.yml \ + --charset dan_humu_page/charset.pkl \ + --output dan_humu_page/predict/ \ + --scale 0.5 \ + --confidence-score \ + --attention-map \ + --attention-map-level word \ + --attention-map-scale 0.5 +``` + +It will create the following JSON file named `dan_humu_page/predict/example.json` and a GIF showing a word-level attention map `dan_humu_page/predict/example_word.gif`. + +```json +{ + "text": "Hansteensgt. 2 IV 28/4 - 19\nKj\u00e6re Gerhard.\nTak for Brevet om Boken og Haven\nog Crokus og Blaaveis og tak fordi\nDu vilde be mig derut sammen\nmed Kris og Ragna. Men vet Du\nda ikke, at Kris reiste med sin S\u00f8-\nster Fru Cr\u00f8ger til Lillehammer\nnogle Dage efter Begravelsen? Hen\ndes Address er Amtsingeni\u00f8r\nCr\u00f8ger. Hun skriver at de blir\nder til lidt ut i Mai. Nu er hun\nnoksaa medtat skj\u00f8nner jeg af Sorg\nog af L\u00e6ngsel, skriver saameget r\u00f8-\nrende om Oluf. Ragna har det\nherligt, skriver hun. Hun er bare\ngla, og det vet jeg, at \"Oluf er gla over,\nder hvor han nu er. Jeg har saa in-\nderlig ondt af hende, og om Du skrev\net Par Ord tror jeg det vilde gj\u00f8re\nhende godt. - Jeg gl\u00e6der mig over,\nat Du har skrevet en Bok, og\njeg er vis paa, at den er god.", + "confidence": 0.99, + "attention_gif": "dan_humu_page/predict/example_word.gif" +} +``` +<video autoplay> + <source src="../assets/example_word.gif"> +</video> + +## Remarks + +The script plotting attention maps assumes that: + +* words are separated with the symbol ` ` +* lines are separated with the symbol `\n` diff --git a/mkdocs.yml b/mkdocs.yml index b5787b36f0480aae8e77a8c9c4c3d5eb50c3b3f9..9c5ecb22b162028afbe4af9cd5ab368811e38dec 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -61,6 +61,7 @@ nav: - Dataset formatting: usage/datasets/format.md - Training: usage/train.md - Generate: usage/generate.md + - Predict: usage/predict.md - Documentation development: dev/build_docs.md - Python Reference: - Datasets: @@ -92,11 +93,13 @@ nav: - Model utils: ref/ocr/line/model_utils.md - Training: ref/ocr/line/train.md - Utils: ref/ocr/line/utils.md + - Prediction: + - Inference: ref/predict/prediction.md + - Attention: ref/predict/attention.md - Decoders: ref/decoder.md - Models: ref/models.md - MLflow: ref/mlflow.md - Post Processing: ref/post_processing.md - - Inference: ref/predict.md - Schedulers: ref/schedulers.md - Transformations: ref/transforms.md - Utils: ref/utils.md