Skip to content
Snippets Groups Projects
Commit e2f3d8f3 authored by Yoann Schneider's avatar Yoann Schneider :tennis:
Browse files

Merge branch 'visualisation-predictions-validation-steps' into 'main'

Visualisation of prediction during validation steps

Closes #292

See merge request !435
parents 74cc388a d3fb2c50
No related branches found
No related tags found
1 merge request!435Visualisation of prediction during validation steps
Showing
with 294 additions and 73 deletions
......@@ -93,7 +93,10 @@
"validation": {
"eval_on_valid": true,
"eval_on_valid_interval": 5,
"set_name_focus_metric": "$dataset_name-val"
"set_name_focus_metric": "$dataset_name-val",
"font": "fonts/LinuxLibertine.ttf",
"maximum_font_size": 32,
"n_tensorboard_images": 5
},
"output_folder": "$dataset_path/output",
"max_nb_epochs": 800,
......
......@@ -84,7 +84,10 @@
"validation": {
"eval_on_valid": true,
"eval_on_valid_interval": 2,
"set_name_focus_metric": "training-val"
"set_name_focus_metric": "training-val",
"font": "fonts/LinuxLibertine.ttf",
"maximum_font_size": 32,
"n_tensorboard_images": 1
},
"output_folder": "dan_trained_model",
"gradient_clipping": {},
......
......@@ -22,6 +22,7 @@ from torch.nn import CrossEntropyLoss
from torch.nn.init import kaiming_uniform_
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from torchvision.transforms import Compose, PILToTensor
from tqdm import tqdm
from dan.ocr.decoder import CTCLanguageDecoder, GlobalHTADecoder
......@@ -30,6 +31,7 @@ from dan.ocr.manager.metrics import Inference, MetricManager
from dan.ocr.manager.ocr import OCRDatasetManager
from dan.ocr.mlflow import MLFLOW_AVAILABLE, logging_metrics, logging_tags_metrics
from dan.ocr.schedulers import DropoutScheduler
from dan.ocr.utils import create_image
from dan.utils import fix_ddp_layers_names, ind_to_token
if MLFLOW_AVAILABLE:
......@@ -56,6 +58,14 @@ class GenericTrainingManager:
self.scaler = None
self.font = self.params["training"]["validation"]["font"]
self.maximum_font_size = self.params["training"]["validation"][
"maximum_font_size"
]
self.n_tensorboard_images = self.params["training"]["validation"][
"n_tensorboard_images"
]
self.optimizers = dict()
self.optimizers_named_params_by_group = dict()
self.lr_schedulers = dict()
......@@ -755,6 +765,7 @@ class GenericTrainingManager:
dataset_name=self.dataset_name,
tokens=self.tokens,
)
with tqdm(total=len(loader.dataset)) as pbar:
pbar.set_description("VALID {} - {}".format(self.latest_epoch, set_name))
with torch.no_grad():
......@@ -777,6 +788,25 @@ class GenericTrainingManager:
pbar.set_postfix(values=str(display_values))
pbar.update(len(batch_data["names"]) * self.nb_workers)
if ind_batch < self.n_tensorboard_images:
image = loader.dataset.get_sample_img(ind_batch)
result = create_image(
image,
batch_values["str_x"][0],
self.font,
self.maximum_font_size,
cer=round(batch_metrics["chars_error_rate"][0] * 100, 2),
wer=round(batch_metrics["words_error_rate"][0] * 100, 2),
)
result_tensor = Compose([PILToTensor()])(result)
self.writer.add_image(
f"valid/image_{batch_data['names'][0]}",
result_tensor,
self.latest_epoch,
)
# log metrics in MLflow
logging_metrics(
display_values,
......
......@@ -6,17 +6,16 @@ import logging
import re
from enum import Enum
from pathlib import Path
from statistics import mean
from typing import List, Tuple
import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image, ImageDraw
from PIL import Image
from torchvision.transforms.functional import to_pil_image
from dan.ocr.utils import load_font
from dan.ocr.utils import create_image
logger = logging.getLogger(__name__)
......@@ -431,41 +430,6 @@ def get_grid_search_contour(coverage_vector, bin_mask, height=50):
return coord, confidence
def search_font_size(
    font: Path,
    maximum_font_size: int,
    text: str,
    width: int,
):
    """
    Search the biggest font size, up to `maximum_font_size`, for which the text fits within the width of the GIF.
    :param font: Path to the font file to use for the GIF of the attention map.
    :param maximum_font_size: Maximum font size to use for the GIF of the attention map.
    :param text: Predicted text.
    :param width: Image width.
    :return: The font loaded at the biggest compatible size, or None if no size fits.
    """
    # Try every size from the biggest to the smallest and stop at the first one that fits
    for font_size in range(maximum_font_size, 0, -1):
        loaded_font = load_font(font, font_size)

        # Right edge of the text bounding box, i.e. the horizontal place taken by the text
        _, _, right, _ = loaded_font.getbbox(text)
        if right < width:
            return loaded_font

    # Only reached when no size fits (the warning is no longer logged when size 1 fits)
    logger.warning("No compatible font size found")
    return None
def plot_attention(
image: torch.Tensor,
text: str,
......@@ -540,7 +504,10 @@ def plot_attention(
)
if contour is not None:
# The image has been scaled so we need to scale the contour
blended = create_image(
blended, text_piece, font, maximum_font_size, contour, scale
)
contour = (contour * scale).astype(np.int32)
if display_polygons:
......@@ -556,33 +523,11 @@ def plot_attention(
# Make the np.array with drawn contours back into a PIL image
blended = Image.fromarray(blended, "RGBA")
# Image size
width, height = blended.size
# Double image size so it have a free with space to write
result = Image.new(image.mode, (width * 2, height), (255, 255, 255))
result.paste(blended, (0, 0))
draw = ImageDraw.Draw(result)
# Search the biggest compatible font size
font_param = search_font_size(font, maximum_font_size, text_piece, width)
if font_param is not None:
# Get the list of every height of every point of the contour
heights = [coord[0][1] for coord in contour.tolist()]
average_height = round(mean(heights))
draw.text(
(width, average_height), text_piece, (0, 0, 0), font=font_param
)
# Keep track of text length
tot_len += len(text_piece) + offset
# Append the blended image to the list of attention maps to be used for the .gif
attention_map.append(result)
attention_map.append(blended)
if not attention_map:
return
......
......@@ -2,11 +2,15 @@
# This code is licensed under CeCILL-C
# -*- coding: utf-8 -*-
import logging
from pathlib import Path
from statistics import mean
from typing import Dict, List, Optional
import torch
from PIL import ImageFont
from numpy import int32, ndarray
from PIL import Image, ImageDraw, ImageFont
from PIL.ImageFont import FreeTypeFont
from prettytable import MARKDOWN, PrettyTable
from torch.optim import Adam
......@@ -14,6 +18,8 @@ from dan.ocr.decoder import GlobalHTADecoder
from dan.ocr.encoder import FCN_Encoder
from dan.ocr.transforms import Preprocessing
logger = logging.getLogger(__name__)
METRICS_TABLE_HEADER = {
"cer": "CER (HTR-NER)",
"cer_no_token": "CER (HTR)",
......@@ -114,3 +120,155 @@ def load_font(path: Path, size: int):
:param size: Size of the font.
"""
return ImageFont.truetype(path, size)
def search_font_size(
    image: Image,
    font: Path,
    maximum_font_size: int,
    text: str,
    width: int | None = None,
    height: int | None = None,
) -> FreeTypeFont | None:
    """
    Search the biggest font size, up to `maximum_font_size`, for which the text fits in the image.
    When `width` is given, the text must be narrower than `width`. Otherwise, when `height` is given,
    the text must be shorter than a tenth of `height`. `width` takes precedence when both are given.
    :param image: Image on which the text is measured.
    :param font: Path to the font file.
    :param maximum_font_size: Maximum font size.
    :param text: Predicted text.
    :param width: Image width.
    :param height: Image height.
    :return: The font loaded at the biggest compatible size, or None if no size fits.
    """
    # Without any constraint there is nothing to fit against: same outcome as an
    # unsuccessful search, without loading every font size for nothing
    if width is None and height is None:
        logger.warning("No compatible font size found")
        return None

    # The drawing context is only used to measure the text, one instance is enough
    draw = ImageDraw.Draw(image)

    # Try every size from the biggest to the smallest and stop at the first one that fits
    for font_size in range(maximum_font_size, 0, -1):
        loaded_font = load_font(font, font_size)

        if width is not None:
            # Get the horizontal place taken by the text
            left, _, right, _ = draw.multiline_textbbox((width, 0), text, loaded_font)
            fits = right - left < width
        else:
            # Get the vertical place taken by the text. `width` is None in this
            # branch, so the text is measured from the origin instead.
            _, top, _, bottom = draw.multiline_textbbox((0, 0), text, loaded_font)
            fits = bottom - top < round(height / 10)

        if fits:
            return loaded_font

    # Only reached when no size fits (the warning is no longer logged when size 1 fits)
    logger.warning("No compatible font size found")
    return None
def search_spacing(
    image: Image,
    font: FreeTypeFont,
    text: str,
    width: int,
    height: int,
    maximum_spacing: int = 50,
) -> int:
    """
    Search the biggest line spacing, up to `maximum_spacing`, for which the multiline text fits within the image height.
    :param image: Image on which the text is measured.
    :param font: Loaded font used to write the text.
    :param text: Predicted text.
    :param width: Image width, used as the horizontal position of the text.
    :param height: Image height.
    :param maximum_spacing: Biggest spacing to try.
    :return: The biggest compatible spacing, or 1 if no spacing fits.
    """
    # The drawing context is only used to measure the text, one instance is enough
    draw = ImageDraw.Draw(image)

    # Try every spacing from the biggest to the smallest and stop at the first one that fits
    for spacing in range(maximum_spacing, 0, -1):
        # Bottom edge of the text when written from the top of the image
        _, _, _, bottom = draw.multiline_textbbox(
            (width, 0), text, font, spacing=spacing
        )
        if bottom < height:
            return spacing

    # Only reached when no spacing fits: fall back to the smallest spacing
    logger.warning("No compatible spacing found: spacing will be set to 1.")
    return 1
def create_image(
    image: Image,
    text: str,
    font: Path,
    maximum_font_size: int,
    contour: ndarray | None = None,
    scale: float | None = None,
    cer: float | None = None,
    wer: float | None = None,
):
    """
    Create an image with the predicted text written on the right of the original image.
    When both `cer` and `wer` are given, a banner displaying them is added above the image.
    :param image: Image predicted.
    :param text: Text predicted from the image.
    :param font: Path to the font file.
    :param maximum_font_size: Maximum font size to use.
    :param contour: Contour of the predicted text on the image.
    :param scale: Scaling factor applied to the contour. The contour is left unscaled when None.
    :param cer: Character error rate to display, in percent.
    :param wer: Word error rate to display, in percent.
    :return: The new image.
    """
    width, height = image.size

    # Double the image width to get a free white space to write on
    new_image = Image.new(image.mode, (width * 2, height), (255, 255, 255))
    new_image.paste(image, (0, 0))
    draw = ImageDraw.Draw(new_image)

    # Search the biggest compatible font size
    font_param = search_font_size(new_image, font, maximum_font_size, text, width)

    if font_param is not None:
        if contour is not None:
            # Scale the contour only when a factor is given (`contour * None` would raise)
            scaled_contour = contour if scale is None else contour * scale
            contour = scaled_contour.astype(int32)

            # Write the text at the average height of the points of the contour
            heights = [coord[0][1] for coord in contour.tolist()]
            average_height = round(mean(heights))
            draw.text((width, average_height), text, (0, 0, 0), font=font_param)
        else:
            # No contour: write the text from the top, spreading the lines over the height
            spacing = search_spacing(new_image, font_param, text, width, height)
            draw.text((width, 0), text, (0, 0, 0), font=font_param, spacing=spacing)

    if cer is None or wer is None:
        return new_image

    # Add a banner of a tenth of the image height at the top to display the metrics
    more_height = round(height / 10)
    new_height = height + more_height
    cer_wer_text = f"CER : {cer}% | WER : {wer}%"

    result = Image.new(new_image.mode, (width * 2, new_height), (255, 255, 255))
    result.paste(new_image, (0, more_height))
    draw = ImageDraw.Draw(result)

    # NOTE(review): the image height is used here as the *width* constraint of the
    # search. The `height` parameter of `search_font_size` looks like what was
    # intended - confirm before changing, behaviour is kept as is.
    font_param = search_font_size(
        result, font, maximum_font_size, cer_wer_text, width=height
    )

    # `textbbox` returns (left, top, right, bottom): measured from the origin, the
    # last two values are used as the size of the text to center it in the banner
    _, _, right, bottom = draw.textbbox((0, 0), cer_wer_text, font=font_param)
    draw.text(
        ((width * 2 - right) / 2, (more_height - bottom) / 2),
        cer_wer_text,
        (0, 0, 0),
        font=font_param,
    )

    return result
docs/assets/tensorboard/example_scalars_train.png

83.8 KiB

docs/assets/tensorboard/example_scalars_val.png

85.4 KiB

docs/assets/tensorboard/example_val.png

326 KiB

docs/assets/tensorboard/example_val_step_190.png

600 KiB

docs/assets/tensorboard/example_val_step_30.png

660 KiB

......@@ -101,11 +101,14 @@ To train on several GPUs, simply set the `training.device.use_ddp` parameter to
### Validation
| Name | Description | Type | Default |
| -------------------------------------------- | -------------------------------------------------------------------------- | ------ | ------- |
| `training.validation.eval_on_valid` | Whether to evaluate and log metrics on the validation set during training. | `bool` | `True` |
| `training.validation.eval_on_valid_interval` | Interval (in epochs) to evaluate during training. | `int` | `5` |
| `training.validation.set_name_focus_metric` | Dataset to focus on to select best weights. | `str` | |
| Name | Description | Type | Default |
| -------------------------------------------- | -------------------------------------------------------------------------- | ------ | -------------------------- |
| `training.validation.eval_on_valid` | Whether to evaluate and log metrics on the validation set during training. | `bool` | `True` |
| `training.validation.eval_on_valid_interval` | Interval (in epochs) to evaluate during training. | `int` | `5` |
| `training.validation.set_name_focus_metric` | Dataset to focus on to select best weights. | `str` | |
| `training.validation.font` | Path to the font used in the image in the tensorboard. | `str` | `fonts/LinuxLibertine.ttf` |
| `training.validation.maximum_font_size` | Maximum size used for the font of the image in the tensorboard. | `int` | |
| `training.validation.n_tensorboard_images` | Number of images in Tensorboard during validation. | `int` | `5` |
During the validation stage, the batch size is set to 1. This avoids problems associated with image sizes that can be very different inside batches and lead to significant padding, resulting in performance degradations.
......
......@@ -6,10 +6,11 @@ To train DAN on your dataset:
1. Create a training JSON configuration file. Refer to the [dedicated page](config.md) for a description of parameters.
1. Run `teklia-dan train --config path/to/your/config.json`.
1. (Optional) Visualize your training in Tensorboard. Refer to the [dedicated page](tensorboard.md).
1. (Optional) Train a language model. Refer to the [dedicated page](language_model.md).
1. Look into the training results in the output folder indicated in your configuration:
- `checkpoints` contains model weights for the last trained epoch and for the epoch giving the best valid CER.
- `results` contains the tensorboard log file and the parameters file.
- `results` contains the [tensorboard](tensorboard.md) log file and the parameters file.
## Additional pages
......
# Tensorboard
DAN relies on Tensorboard to log metrics and predictions. This allows you to monitor the progress of your training.
## Access
To access Tensorboard, run `tensorboard --logdir={output_folder}/results/` in your local terminal.
Then, go to http://localhost:6006 to visualize your training.
## Metrics
Two metrics are commonly used to evaluate Automatic Text Recognition models.
- the Character Error Rate (CER) is the percentage of characters that have been transcribed incorrectly by the model.
- the Word Error Rate (WER) is the percentage of words that have been transcribed incorrectly by the model.
## Usage
Seven metrics are computed on the training set and six on the validation set (all but the loss), and logged in Tensorboard. In addition, predictions on validation images are also logged (5 by default).
### Training metrics
Several metrics are computed on the training set:
- `train/{dataset}-train_loss_ce`: the cross entropy loss function.
- `train/{dataset}-train_cer`: the CER.
- `train/{dataset}-train_cer_no_token`: the CER ignoring Named Entity (NE) tokens (only characters are considered).
- `train/{dataset}-train_ner`: the CER ignoring characters (only NE tokens are considered).
- `train/{dataset}-train_wer`: the WER.
- `train/{dataset}-train_wer_no_punct`: the WER ignoring punctuation marks.
- `train/{dataset}-train_wer_no_token`: the WER ignoring Named Entity (NE) tokens (only characters are considered).
These metrics can be visualized in the `Scalars` tab in Tensorboard, under the `train` section.
<img src="../../../assets/tensorboard/example_scalars_train.png" />
Alternatively, you can find them in the `Time series` tab.
### Validation metrics
The same metrics are computed on the validation set, except for the loss function:
- `val/{dataset}-val_cer`: the CER.
- `val/{dataset}-val_cer_no_token`: the CER ignoring Named Entity (NE) tokens (only characters are considered).
- `val/{dataset}-val_ner`: the CER ignoring characters (only NE tokens are considered).
- `val/{dataset}-val_wer`: the WER.
- `val/{dataset}-val_wer_no_punct`: the WER ignoring punctuation marks.
- `val/{dataset}-val_wer_no_token`: the WER ignoring Named Entity (NE) tokens (only characters are considered).
These metrics can be visualized in the `Scalars` tab in Tensorboard, under the `valid` section.
<img src="../../../assets/tensorboard/example_scalars_val.png" />
Alternatively, you can find them in the `Time series` tab.
### Predictions on the validation set
By default, five validation images are also displayed at each validation step, along with their predicted transcription, CER and WER.
To log more or fewer images, update the `training.validation.n_tensorboard_images` parameter in the [configuration file](config.md). The font and its maximum size can also be changed.
To visualize them, go to the `Image` tab in Tensorboard.
<img src="../../../assets/tensorboard/example_val.png" />
Select an image to increase its size:
<img src="../../../assets/tensorboard/example_val_step_190.png" />
By default, the slider is set to the last validation step. You can move the cursor to view previous transcriptions on the same image:
<img src="../../../assets/tensorboard/example_val_step_30.png" />
......@@ -71,6 +71,7 @@ nav:
- Configuration: usage/train/config.md
- Data augmentation: usage/train/augmentation.md
- Language model: usage/train/language_model.md
- Tensorboard: usage/train/tensorboard.md
- Jean Zay tutorial: usage/train/jeanzay.md
- Evaluation: usage/evaluate/index.md
- Prediction: usage/predict/index.md
......
tests/data/prediction/0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84_line.gif

130 B | W: | H:

tests/data/prediction/0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84_line.gif

130 B | W: | H:

tests/data/prediction/0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84_line.gif
tests/data/prediction/0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84_line.gif
tests/data/prediction/0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84_line.gif
tests/data/prediction/0a56e8b3-95cd-4fa5-a17b-5b0ff9e6ea84_line.gif
  • 2-up
  • Swipe
  • Onion skin
......@@ -211,6 +211,10 @@ def test_evaluate(
output_json = tmp_path / "inference.json" if is_output_json else None
evaluate_config["training"]["validation"]["font"] = "fonts/LinuxLibertine.ttf"
evaluate_config["training"]["validation"]["maximum_font_size"] = 32
evaluate_config["training"]["validation"]["n_tensorboard_images"] = 5
evaluate.run(evaluate_config, evaluate.NERVAL_THRESHOLD, output_json=output_json)
if is_output_json:
......@@ -378,6 +382,10 @@ def test_evaluate_language_model(
"weight": language_model_weight,
}
evaluate_config["training"]["validation"]["font"] = "fonts/LinuxLibertine.ttf"
evaluate_config["training"]["validation"]["maximum_font_size"] = 32
evaluate_config["training"]["validation"]["n_tensorboard_images"] = 5
evaluate.run(evaluate_config, evaluate.NERVAL_THRESHOLD, output_json=None)
# Check that the evaluation results are correct
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment