Skip to content
Snippets Groups Projects
Commit 9a7e3b6d authored by Yoann Schneider's avatar Yoann Schneider :tennis:
Browse files

Merge branch 'fix-objects-prediction' into 'main'

Fix objects prediction and GIF generation

Closes #282

See merge request !417
parents fcebc8aa 693ab1cc
No related branches found
No related tags found
1 merge request: !417 Fix objects prediction and GIF generation
......@@ -24,7 +24,9 @@ Then one can initialize and load the trained model with the parameters used duri
- a `parameters.yml` file corresponding to the `inference_parameters.yml` file generated during training.
```python
model_path = "models"
from pathlib import Path
model_path = Path("models")
model = DAN("cpu")
model.load(model_path, mode="eval")
......@@ -33,7 +35,24 @@ model.load(model_path, mode="eval")
To run inference on a GPU, replace `cpu` with the name of the GPU device. Finally, run the prediction:
```python
text, confidence_scores = model.predict(image, confidences=True)
from pathlib import Path
from dan.utils import parse_charset_pattern
# Load image
image_path = Path("images/page.jpg")
_, image = model.preprocess(str(image_path))
input_tensor = image.unsqueeze(0)
input_tensor = input_tensor.to("cpu")
input_sizes = [image.shape[1:]]
# Predict
text, confidence_scores = model.predict(
input_tensor,
input_sizes,
char_separators=parse_charset_pattern(model.charset),
confidences=True,
)
```
## Training
......
......@@ -5,7 +5,7 @@
import logging
import re
from enum import Enum
from typing import Dict, List, Tuple
from typing import List, Tuple
import cv2
import matplotlib.pyplot as plt
......@@ -14,8 +14,6 @@ import torch
from PIL import Image
from torchvision.transforms.functional import to_pil_image
from dan.utils import EntityType
logger = logging.getLogger(__name__)
......@@ -440,10 +438,11 @@ def plot_attention(
outname: str,
alpha_factor: float,
color_map: str,
char_separators: re.Pattern,
max_object_height: int = 50,
word_separators: re.Pattern = parse_delimiters(["\n", " "]),
line_separators: re.Pattern = parse_delimiters(["\n"]),
tokens: Dict[str, EntityType] = {},
tokens_separators: re.Pattern | None = None,
display_polygons: bool = False,
) -> None:
"""
......@@ -456,10 +455,11 @@ def plot_attention(
:param outname: Name of the gif image
:param alpha_factor: Alpha factor that controls how much the attention map is shown to the user during prediction. (higher value means more transparency for the attention map, commonly between 0.5 and 1.0)
:param color_map: Colormap to use for the attention map
:param char_separators: Pattern used to find tokens of the charset
:param max_object_height: Maximum height of predicted objects.
:param word_separators: List of word separators
:param line_separators: List of line separators
:param tokens: NER tokens used
:param word_separators: Pattern used to find words
:param line_separators: Pattern used to find lines
:param tokens_separators: Pattern used to find NER entities
:param display_polygons: Whether to plot extracted polygons
"""
image = to_pil_image(image)
......@@ -467,7 +467,12 @@ def plot_attention(
# Split text into characters, words or lines
text_list, offsets = split_text(
text, level, word_separators, line_separators, tokens
text,
level,
char_separators,
word_separators,
line_separators,
tokens_separators,
)
# Iterate on characters, words or lines
......
......@@ -170,13 +170,14 @@ class DAN:
self,
input_tensor: torch.Tensor,
input_sizes: List[torch.Size],
char_separators: re.Pattern,
confidences: bool = False,
attentions: bool = False,
attention_level: Level = Level.Line,
extract_objects: bool = False,
word_separators: re.Pattern = parse_delimiters(["\n", " "]),
line_separators: re.Pattern = parse_delimiters(["\n"]),
tokens: Dict[str, EntityType] = {},
tokens_separators: re.Pattern | None = None,
start_token: str | None = None,
max_object_height: int = 50,
) -> dict:
......@@ -184,10 +185,15 @@ class DAN:
Run prediction on an input image.
:param input_tensor: A batch of images to predict.
:param input_sizes: The original images sizes.
:param char_separators: The regular expression pattern to split characters.
:param confidences: Return the characters probabilities.
:param attentions: Return characters attention weights.
:param attention_level: Level of text pieces (must be in [char, word, line, ner])
:param extract_objects: Whether to extract polygons' coordinates.
:param word_separators: The regular expression pattern to split words.
:param line_separators: The regular expression pattern to split lines.
:param tokens_separators: The regular expression pattern to split NER tokens.
:param start_token: The starting token for the prediction.
:param max_object_height: Maximum height of predicted objects.
"""
input_tensor = input_tensor.to(self.device)
......@@ -320,9 +326,10 @@ class DAN:
input_sizes[i][0],
input_sizes[i][1],
max_object_height=max_object_height,
char_separators=char_separators,
word_separators=word_separators,
line_separators=line_separators,
tokens=tokens,
tokens_separators=tokens_separators,
)
for i in range(batch_size)
]
......@@ -378,9 +385,10 @@ def process_batch(
attentions=attention_map,
attention_level=attention_map_level,
extract_objects=predict_objects,
char_separators=char_separators,
word_separators=word_separators,
line_separators=line_separators,
tokens=tokens,
tokens_separators=ner_separators,
max_object_height=max_object_height,
start_token=start_token,
)
......@@ -427,7 +435,9 @@ def process_batch(
# Save gif with attention map
if attention_map:
attentions = prediction["attentions"][idx]
gif_filename = f"{output}/{image_path.stem}_{attention_map_level}.gif"
gif_filename = (
f"{output}/{image_path.stem}_{attention_map_level.value}.gif"
)
logger.info(f"Creating attention GIF in {gif_filename}")
plot_attention(
image=visu_tensor[idx],
......@@ -437,6 +447,7 @@ def process_batch(
scale=attention_map_scale,
alpha_factor=alpha_factor,
color_map=color_map,
char_separators=char_separators,
word_separators=word_separators,
line_separators=line_separators,
tokens_separators=ner_separators,
......@@ -482,13 +493,18 @@ def run(
:param model: Path to the directory containing the model, the YAML parameters file and the charset file to use for prediction.
:param output: Path to the output folder where the results will be saved.
:param confidence_score: Whether to compute confidence score.
:param confidence_score_levels: Levels of objects to extract.
:param attention_map: Whether to plot the attention map.
:param attention_map_level: Level of objects to extract.
:param attention_map_scale: Scaling factor for the attention map.
:param alpha_factor: Alpha factor for the attention map.
:param color_map: A matplotlib colormap to use for the attention maps.
:param word_separators: List of word separators.
:param line_separators: List of line separators.
:param temperature: Temperature scalar parameter.
:param predict_objects: Whether to extract objects.
:param max_object_height: Maximum height of predicted objects.
:param image_extension: Extension of the images to predict.
:param gpu_device: Use a specific GPU if available.
:param batch_size: Size of the batches for prediction.
:param tokens: NER tokens used.
......
......@@ -5,6 +5,7 @@
import json
import shutil
from pathlib import Path
import numpy as np
import pytest
......@@ -13,7 +14,11 @@ import yaml
from dan.ocr.predict.attention import Level
from dan.ocr.predict.inference import DAN
from dan.ocr.predict.inference import run as run_prediction
from dan.utils import parse_tokens, read_yaml
from dan.utils import (
parse_charset_pattern,
parse_tokens,
read_yaml,
)
from tests import FIXTURES
PREDICTION_DATA_PATH = FIXTURES / "prediction"
......@@ -73,18 +78,23 @@ def test_predict(image_name, expected_prediction, normalize, tmp_path):
input_tensor = input_tensor.to(device)
input_sizes = [image.shape[1:]]
prediction = dan_model.predict(input_tensor, input_sizes)
prediction = dan_model.predict(
input_tensor,
input_sizes,
char_separators=parse_charset_pattern(dan_model.charset),
)
assert prediction == expected_prediction
@pytest.mark.parametrize(
"image_name, confidence_score, temperature, expected_prediction",
"image_name, confidence_score, temperature, predict_objects, expected_prediction",
(
(
"0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84",
None,
1.0,
[], # Confidence score
1.0, # Temperature
False, # Predict objects
{
"text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",
"language_model": {},
......@@ -93,8 +103,9 @@ def test_predict(image_name, expected_prediction, normalize, tmp_path):
),
(
"0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84",
[Level.Word],
1.0,
[Level.Word], # Confidence score
1.0, # Temperature
True, # Predict objects
{
"text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",
"language_model": {},
......@@ -111,12 +122,64 @@ def test_predict(image_name, expected_prediction, normalize, tmp_path):
{"text": "ⓅPatron?12241", "confidence": 1.0},
],
},
"objects": [
{
"confidence": 0.42,
"polygon": [[0, 0], [144, 0], [144, 66], [0, 66]],
"text": "ⓈBellisson",
"text_confidence": 1.0,
},
{
"confidence": 0.52,
"polygon": [[184, 0], [269, 0], [269, 66], [184, 66]],
"text": "ⒻGeorges",
"text_confidence": 1.0,
},
{
"confidence": 0.21,
"polygon": [[294, 0], [371, 0], [371, 66], [294, 66]],
"text": "Ⓑ91",
"text_confidence": 1.0,
},
{
"confidence": 0.23,
"polygon": [[367, 0], [427, 0], [427, 66], [367, 66]],
"text": "ⓁP",
"text_confidence": 1.0,
},
{
"confidence": 0.18,
"polygon": [[535, 0], [619, 0], [619, 66], [535, 66]],
"text": "ⒸM",
"text_confidence": 1.0,
},
{
"confidence": 0.23,
"polygon": [[589, 0], [674, 0], [674, 66], [589, 66]],
"text": "ⓀCh",
"text_confidence": 1.0,
},
{
"confidence": 0.31,
"polygon": [[685, 0], [806, 0], [806, 66], [685, 66]],
"text": "ⓄPlombier",
"text_confidence": 1.0,
},
{
"confidence": 0.91,
"polygon": [[820, 0], [938, 0], [938, 66], [820, 66]],
"text": "ⓅPatron?12241",
"text_confidence": 1.0,
},
],
"attention_gif": "0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84_word.gif",
},
),
(
"0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84",
[Level.NER, Level.Word],
3.5,
[Level.NER, Level.Word], # Confidence score
3.5, # Temperature
False, # Predict objects
{
"text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",
"language_model": {},
......@@ -147,8 +210,9 @@ def test_predict(image_name, expected_prediction, normalize, tmp_path):
),
(
"0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84",
[Level.Line],
1.0,
[Level.Line], # Confidence score
1.0, # Temperature
False, # Predict objects
{
"text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",
"language_model": {},
......@@ -165,8 +229,9 @@ def test_predict(image_name, expected_prediction, normalize, tmp_path):
),
(
"0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84",
[Level.NER, Level.Line],
3.5,
[Level.NER, Level.Line], # Confidence score
3.5, # Temperature
False, # Predict objects
{
"text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",
"language_model": {},
......@@ -193,8 +258,9 @@ def test_predict(image_name, expected_prediction, normalize, tmp_path):
),
(
"0dfe8bcd-ed0b-453e-bf19-cc697012296e",
None,
1.0,
[], # Confidence score
1.0, # Temperature
False, # Predict objects
{
"text": "ⓈTemplié ⒻMarcelle Ⓑ93 ⓁS Ⓚch ⓄE dactylo Ⓟ18376",
"language_model": {},
......@@ -203,8 +269,9 @@ def test_predict(image_name, expected_prediction, normalize, tmp_path):
),
(
"0dfe8bcd-ed0b-453e-bf19-cc697012296e",
[Level.NER, Level.Char, Level.Word, Level.Line],
1.0,
[Level.NER, Level.Char, Level.Word, Level.Line], # Confidence score
1.0, # Temperature
False, # Predict objects
{
"text": "ⓈTemplié ⒻMarcelle Ⓑ93 ⓁS Ⓚch ⓄE dactylo Ⓟ18376",
"language_model": {},
......@@ -289,8 +356,9 @@ def test_predict(image_name, expected_prediction, normalize, tmp_path):
),
(
"2c242f5c-e979-43c4-b6f2-a6d4815b651d",
False,
1.0,
[], # Confidence score
1.0, # Temperature
False, # Predict objects
{
"text": "Ⓢd ⒻCharles Ⓑ11 ⓁP ⒸC ⓀF Ⓞd Ⓟ14 31",
"language_model": {},
......@@ -299,12 +367,21 @@ def test_predict(image_name, expected_prediction, normalize, tmp_path):
),
(
"ffdec445-7f14-4f5f-be44-68d0844d0df1",
False,
1.0,
[], # Confidence score
1.0, # Temperature
True, # Predict objects
{
"text": "ⓈNaudin ⒻMarie Ⓑ53 ⓁS Ⓒv ⓀBelle mère",
"language_model": {},
"confidences": {},
"objects": [
{
"confidence": 0.96,
"polygon": [[546, 0], [715, 0], [715, 67], [546, 67]],
"text": "ⓈNaudin ⒻMarie Ⓑ53 ⓁS Ⓒv ⓀBelle mère",
"text_confidence": 1.0,
}
],
},
),
),
......@@ -313,9 +390,15 @@ def test_run_prediction(
image_name,
confidence_score,
temperature,
predict_objects,
expected_prediction,
tmp_path,
):
if "attention_gif" in expected_prediction:
expected_prediction["attention_gif"] = str(
tmp_path / expected_prediction["attention_gif"]
)
# Make tmpdir and copy needed image inside
image_dir = tmp_path / "images"
image_dir.mkdir()
......@@ -328,17 +411,17 @@ def test_run_prediction(
image_dir=image_dir,
model=PREDICTION_DATA_PATH,
output=tmp_path,
confidence_score=True if confidence_score else False,
confidence_score_levels=confidence_score if confidence_score else [],
attention_map=False,
attention_map_level=None,
confidence_score=bool(confidence_score),
confidence_score_levels=confidence_score,
attention_map=predict_objects and confidence_score,
attention_map_level=[Level.Line, *confidence_score].pop(),
attention_map_scale=0.5,
alpha_factor=0.9,
color_map="nipy_spectral",
word_separators=[" ", "\n"],
line_separators=["\n"],
temperature=temperature,
predict_objects=False,
predict_objects=predict_objects,
max_object_height=None,
image_extension=".png",
gpu_device=None,
......@@ -352,6 +435,8 @@ def test_run_prediction(
prediction = json.loads((tmp_path / image_name).with_suffix(".json").read_text())
assert prediction == expected_prediction
if "attention_gif" in expected_prediction:
assert Path(expected_prediction["attention_gif"]).exists()
@pytest.mark.parametrize(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment