diff --git a/README.md b/README.md index b91da1154446be1b238f605a8a122fed0fc59f32..5aa38bfde4c879b02f3ab458ef4fa8252d71c807 100644 --- a/README.md +++ b/README.md @@ -24,7 +24,9 @@ Then one can initialize and load the trained model with the parameters used duri - a `parameters.yml` file corresponding to the `inference_parameters.yml` file generated during training. ```python -model_path = "models" +from pathlib import Path + +model_path = Path("models") model = DAN("cpu") model.load(model_path, mode="eval") @@ -33,7 +35,24 @@ model.load(model_path, mode="eval") To run the inference on a GPU, one can replace `cpu` by the name of the GPU. In the end, one can run the prediction: ```python -text, confidence_scores = model.predict(image, confidences=True) +from pathlib import Path +from dan.utils import parse_charset_pattern + +# Load image +image_path = "images/page.jpg" +_, image = dan_model.preprocess(str(image_path)) + +input_tensor = image.unsqueeze(0) +input_tensor = input_tensor.to("cpu") +input_sizes = [image.shape[1:]] + +# Predict +text, confidence_scores = model.predict( + input_tensor, + input_sizes, + char_separators=parse_charset_pattern(dan_model.charset), + confidences=True, +) ``` ## Training diff --git a/dan/ocr/predict/attention.py b/dan/ocr/predict/attention.py index 0aa32b25ad24694e9baf2b5914c7cba0ea7e451d..d12ad81dc0d7e08fee7599fa02a9ed8487e16c8e 100644 --- a/dan/ocr/predict/attention.py +++ b/dan/ocr/predict/attention.py @@ -5,7 +5,7 @@ import logging import re from enum import Enum -from typing import Dict, List, Tuple +from typing import List, Tuple import cv2 import matplotlib.pyplot as plt @@ -14,8 +14,6 @@ import torch from PIL import Image from torchvision.transforms.functional import to_pil_image -from dan.utils import EntityType - logger = logging.getLogger(__name__) @@ -440,10 +438,11 @@ def plot_attention( outname: str, alpha_factor: float, color_map: str, + char_separators: re.Pattern, max_object_height: int = 50, word_separators: re.Pattern = parse_delimiters(["\n", " "]), line_separators: re.Pattern = parse_delimiters(["\n"]), - tokens: Dict[str, EntityType] = {}, + tokens_separators: re.Pattern | None = None, display_polygons: bool = False, ) -> None: """ @@ -456,10 +455,11 @@ def plot_attention( :param outname: Name of the gif image :param alpha_factor: Alpha factor that controls how much the attention map is shown to the user during prediction. (higher value means more transparency for the attention map, commonly between 0.5 and 1.0) :param color_map: Colormap to use for the attention map + :param char_separators: Pattern used to find tokens of the charset :param max_object_height: Maximum height of predicted objects. - :param word_separators: List of word separators - :param line_separators: List of line separators - :param tokens: NER tokens used + :param word_separators: Pattern used to find words + :param line_separators: Pattern used to find lines + :param tokens_separators: Pattern used to find NER entities :param display_polygons: Whether to plot extracted polygons """ image = to_pil_image(image) @@ -467,7 +467,12 @@ def plot_attention( # Split text into characters, words or lines text_list, offsets = split_text( - text, level, word_separators, line_separators, tokens + text, + level, + char_separators, + word_separators, + line_separators, + tokens_separators, ) # Iterate on characters, words or lines diff --git a/dan/ocr/predict/inference.py b/dan/ocr/predict/inference.py index 943d02cef254f0c388f9e3c55f1dc23b97b9e837..4ca3dc52623620c62082afc37490b65ae125ebae 100644 --- a/dan/ocr/predict/inference.py +++ b/dan/ocr/predict/inference.py @@ -170,13 +170,14 @@ class DAN: self, input_tensor: torch.Tensor, input_sizes: List[torch.Size], + char_separators: re.Pattern, confidences: bool = False, attentions: bool = False, attention_level: Level = Level.Line, extract_objects: bool = False, word_separators: re.Pattern = parse_delimiters(["\n", " "]), line_separators: re.Pattern = parse_delimiters(["\n"]), - tokens: Dict[str, EntityType] = {}, + tokens_separators: re.Pattern | None = None, start_token: str | None = None, max_object_height: int = 50, ) -> dict: @@ -184,10 +185,15 @@ class DAN: Run prediction on an input image. :param input_tensor: A batch of images to predict. :param input_sizes: The original images sizes. + :param char_separators: The regular expression pattern to split characters. :param confidences: Return the characters probabilities. :param attentions: Return characters attention weights. :param attention_level: Level of text pieces (must be in [char, word, line, ner]) :param extract_objects: Whether to extract polygons' coordinates. + :param word_separators: The regular expression pattern to split words. + :param line_separators: The regular expression pattern to split lines. + :param tokens_separators: The regular expression pattern to split NER tokens. + :param start_token: The starting token for the prediction. :param max_object_height: Maximum height of predicted objects. """ input_tensor = input_tensor.to(self.device) @@ -320,9 +326,10 @@ class DAN: input_sizes[i][0], input_sizes[i][1], max_object_height=max_object_height, + char_separators=char_separators, word_separators=word_separators, line_separators=line_separators, - tokens=tokens, + tokens_separators=tokens_separators, ) for i in range(batch_size) ] @@ -378,9 +385,10 @@ def process_batch( attentions=attention_map, attention_level=attention_map_level, extract_objects=predict_objects, + char_separators=char_separators, word_separators=word_separators, line_separators=line_separators, - tokens=tokens, + tokens_separators=ner_separators, max_object_height=max_object_height, start_token=start_token, ) @@ -427,7 +435,9 @@ def process_batch( # Save gif with attention map if attention_map: attentions = prediction["attentions"][idx] - gif_filename = f"{output}/{image_path.stem}_{attention_map_level}.gif" + gif_filename = ( + f"{output}/{image_path.stem}_{attention_map_level.value}.gif" + ) logger.info(f"Creating attention GIF in {gif_filename}") plot_attention( image=visu_tensor[idx], @@ -437,6 +447,7 @@ def process_batch( scale=attention_map_scale, alpha_factor=alpha_factor, color_map=color_map, + char_separators=char_separators, word_separators=word_separators, line_separators=line_separators, tokens_separators=ner_separators, @@ -482,13 +493,18 @@ def run( :param model: Path to the directory containing the model, the YAML parameters file and the charset file to use for prediction. :param output: Path to the output folder where the results will be saved. :param confidence_score: Whether to compute confidence score. + :param confidence_score_levels: Levels of objects to extract. :param attention_map: Whether to plot the attention map. :param attention_map_level: Level of objects to extract. :param attention_map_scale: Scaling factor for the attention map. + :param alpha_factor: Alpha factor for the attention map. + :param color_map: A matplotlib colormap to use for the attention maps. :param word_separators: List of word separators. :param line_separators: List of line separators. + :param temperature: Temperature scalar parameter. :param predict_objects: Whether to extract objects. :param max_object_height: Maximum height of predicted objects. + :param image_extension: Extension of the images to predict. :param gpu_device: Use a specific GPU if available. :param batch_size: Size of the batches for prediction. :param tokens: NER tokens used. diff --git a/tests/test_prediction.py b/tests/test_prediction.py index 189e9eeb6984f4bf1406b4b96d63fe199b0bc34a..36365097455199951d0fe0437782e0f3d32470a1 100644 --- a/tests/test_prediction.py +++ b/tests/test_prediction.py @@ -5,6 +5,7 @@ import json import shutil +from pathlib import Path import numpy as np import pytest @@ -13,7 +14,11 @@ import yaml from dan.ocr.predict.attention import Level from dan.ocr.predict.inference import DAN from dan.ocr.predict.inference import run as run_prediction -from dan.utils import parse_tokens, read_yaml +from dan.utils import ( + parse_charset_pattern, + parse_tokens, + read_yaml, +) from tests import FIXTURES PREDICTION_DATA_PATH = FIXTURES / "prediction" @@ -73,18 +78,23 @@ def test_predict(image_name, expected_prediction, normalize, tmp_path): input_tensor = input_tensor.to(device) input_sizes = [image.shape[1:]] - prediction = dan_model.predict(input_tensor, input_sizes) + prediction = dan_model.predict( + input_tensor, + input_sizes, + char_separators=parse_charset_pattern(dan_model.charset), + ) assert prediction == expected_prediction @pytest.mark.parametrize( - "image_name, confidence_score, temperature, expected_prediction", + "image_name, confidence_score, temperature, predict_objects, expected_prediction", ( ( "0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84", - None, - 1.0, + [], # Confidence score + 1.0, # Temperature + False, # Predict objects { "text": "ⓈBellisson â’»Georges â’·91 â“P â’¸M â“€Ch â“„Plombier â“…Patron?12241", "language_model": {}, @@ -93,8 +103,9 @@ def test_predict(image_name, expected_prediction, normalize, tmp_path): ), ( "0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84", - [Level.Word], - 1.0, + [Level.Word], # Confidence score + 1.0, # Temperature + True, # Predict objects { "text": "ⓈBellisson â’»Georges â’·91 â“P â’¸M â“€Ch â“„Plombier â“…Patron?12241", "language_model": {}, @@ -111,12 +122,64 @@ def test_predict(image_name, expected_prediction, normalize, tmp_path): {"text": "â“…Patron?12241", "confidence": 1.0}, ], }, + "objects": [ + { + "confidence": 0.42, + "polygon": [[0, 0], [144, 0], [144, 66], [0, 66]], + "text": "ⓈBellisson", + "text_confidence": 1.0, + }, + { + "confidence": 0.52, + "polygon": [[184, 0], [269, 0], [269, 66], [184, 66]], + "text": "â’»Georges", + "text_confidence": 1.0, + }, + { + "confidence": 0.21, + "polygon": [[294, 0], [371, 0], [371, 66], [294, 66]], + "text": "â’·91", + "text_confidence": 1.0, + }, + { + "confidence": 0.23, + "polygon": [[367, 0], [427, 0], [427, 66], [367, 66]], + "text": "â“P", + "text_confidence": 1.0, + }, + { + "confidence": 0.18, + "polygon": [[535, 0], [619, 0], [619, 66], [535, 66]], + "text": "â’¸M", + "text_confidence": 1.0, + }, + { + "confidence": 0.23, + "polygon": [[589, 0], [674, 0], [674, 66], [589, 66]], + "text": "â“€Ch", + "text_confidence": 1.0, + }, + { + "confidence": 0.31, + "polygon": [[685, 0], [806, 0], [806, 66], [685, 66]], + "text": "â“„Plombier", + "text_confidence": 1.0, + }, + { + "confidence": 0.91, + "polygon": [[820, 0], [938, 0], [938, 66], [820, 66]], + "text": "â“…Patron?12241", + "text_confidence": 1.0, + }, + ], + "attention_gif": "0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84_word.gif", }, ), ( "0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84", - [Level.NER, Level.Word], - 3.5, + [Level.NER, Level.Word], # Confidence score + 3.5, # Temperature + False, # Predict objects { "text": "ⓈBellisson â’»Georges â’·91 â“P â’¸M â“€Ch â“„Plombier â“…Patron?12241", "language_model": {}, @@ -147,8 +210,9 @@ def test_predict(image_name, expected_prediction, normalize, tmp_path): ), ( "0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84", - [Level.Line], - 1.0, + [Level.Line], # Confidence score + 1.0, # Temperature + False, # Predict objects { "text": "ⓈBellisson â’»Georges â’·91 â“P â’¸M â“€Ch â“„Plombier â“…Patron?12241", "language_model": {}, @@ -165,8 +229,9 @@ def test_predict(image_name, expected_prediction, normalize, tmp_path): ), ( "0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84", - [Level.NER, Level.Line], - 3.5, + [Level.NER, Level.Line], # Confidence score + 3.5, # Temperature + False, # Predict objects { "text": "ⓈBellisson â’»Georges â’·91 â“P â’¸M â“€Ch â“„Plombier â“…Patron?12241", "language_model": {}, @@ -193,8 +258,9 @@ def test_predict(image_name, expected_prediction, normalize, tmp_path): ), ( "0dfe8bcd-ed0b-453e-bf19-cc697012296e", - None, - 1.0, + [], # Confidence score + 1.0, # Temperature + False, # Predict objects { "text": "ⓈTemplié â’»Marcelle â’·93 â“S â“€ch â“„E dactylo â“…18376", "language_model": {}, @@ -203,8 +269,9 @@ def test_predict(image_name, expected_prediction, normalize, tmp_path): ), ( "0dfe8bcd-ed0b-453e-bf19-cc697012296e", - [Level.NER, Level.Char, Level.Word, Level.Line], - 1.0, + [Level.NER, Level.Char, Level.Word, Level.Line], # Confidence score + 1.0, # Temperature + False, # Predict objects { "text": "ⓈTemplié â’»Marcelle â’·93 â“S â“€ch â“„E dactylo â“…18376", "language_model": {}, @@ -289,8 +356,9 @@ def test_predict(image_name, expected_prediction, normalize, tmp_path): ), ( "2c242f5c-e979-43c4-b6f2-a6d4815b651d", - False, - 1.0, + [], # Confidence score + 1.0, # Temperature + False, # Predict objects { "text": "Ⓢd â’»Charles â’·11 â“P â’¸C â“€F â“„d â“…14 31", "language_model": {}, @@ -299,12 +367,21 @@ def test_predict(image_name, expected_prediction, normalize, tmp_path): ), ( "ffdec445-7f14-4f5f-be44-68d0844d0df1", - False, - 1.0, + [], # Confidence score + 1.0, # Temperature + True, # Predict objects { "text": "ⓈNaudin â’»Marie â’·53 â“S â’¸v â“€Belle mère", "language_model": {}, "confidences": {}, + "objects": [ + { + "confidence": 0.96, + "polygon": [[546, 0], [715, 0], [715, 67], [546, 67]], + "text": "ⓈNaudin â’»Marie â’·53 â“S â’¸v â“€Belle mère", + "text_confidence": 1.0, + } + ], }, ), ), @@ -313,9 +390,15 @@ def test_run_prediction( image_name, confidence_score, temperature, + predict_objects, expected_prediction, tmp_path, ): + if "attention_gif" in expected_prediction: + expected_prediction["attention_gif"] = str( + tmp_path / expected_prediction["attention_gif"] + ) + # Make tmpdir and copy needed image inside image_dir = tmp_path / "images" image_dir.mkdir() @@ -328,17 +411,17 @@ def test_run_prediction( image_dir=image_dir, model=PREDICTION_DATA_PATH, output=tmp_path, - confidence_score=True if confidence_score else False, - confidence_score_levels=confidence_score if confidence_score else [], - attention_map=False, - attention_map_level=None, + confidence_score=bool(confidence_score), + confidence_score_levels=confidence_score, + attention_map=predict_objects and confidence_score, + attention_map_level=[Level.Line, *confidence_score].pop(), attention_map_scale=0.5, alpha_factor=0.9, color_map="nipy_spectral", word_separators=[" ", "\n"], line_separators=["\n"], temperature=temperature, - predict_objects=False, + predict_objects=predict_objects, max_object_height=None, image_extension=".png", gpu_device=None, @@ -352,6 +435,8 @@ def test_run_prediction( prediction = json.loads((tmp_path / image_name).with_suffix(".json").read_text()) assert prediction == expected_prediction + if "attention_gif" in expected_prediction: + assert Path(expected_prediction["attention_gif"]).exists() @pytest.mark.parametrize(