From c5fa54a35efbda482390a2235348bdcdbe8385a4 Mon Sep 17 00:00:00 2001
From: Yoann Schneider <yschneider@teklia.com>
Date: Mon, 21 Aug 2023 13:38:49 +0000
Subject: [PATCH] Parse tokens during argparse parsing and expose start_token

---
 dan/ocr/predict/__init__.py   | 14 ++++++++----
 dan/ocr/predict/prediction.py | 15 ++++++++-----
 dan/utils.py                  |  4 ++--
 docs/usage/predict.md         | 41 ++++++++++++++++++-----------------
 tests/test_prediction.py      |  7 ++++--
 5 files changed, 48 insertions(+), 33 deletions(-)

diff --git a/dan/ocr/predict/__init__.py b/dan/ocr/predict/__init__.py
index 1c65254a..0ee92b8b 100644
--- a/dan/ocr/predict/__init__.py
+++ b/dan/ocr/predict/__init__.py
@@ -6,6 +6,7 @@ Predict on an image using a trained DAN model.
 import pathlib
 
 from dan.ocr.predict.prediction import run
+from dan.utils import parse_tokens
 
 
 def add_predict_parser(subcommands) -> None:
@@ -51,13 +52,13 @@ def add_predict_parser(subcommands) -> None:
         help="Path to the output folder.",
         required=True,
     )
+    # Optional arguments.
     parser.add_argument(
         "--tokens",
-        type=pathlib.Path,
-        required=True,
+        type=parse_tokens,
+        required=False,
         help="Path to a yaml file containing a mapping between starting tokens and end tokens. Needed for entities.",
     )
-    # Optional arguments.
     parser.add_argument(
         "--image-extension",
         type=str,
@@ -154,5 +155,10 @@ def add_predict_parser(subcommands) -> None:
         default=1,
         required=False,
     )
-
+    parser.add_argument(
+        "--start-token",
+        help="Use a specific starting token at the beginning of the prediction. Useful when making predictions on different single pages.",
+        type=str,
+        required=False,
+    )
     parser.set_defaults(func=run)
diff --git a/dan/ocr/predict/prediction.py b/dan/ocr/predict/prediction.py
index 32758f8a..cdca2a90 100644
--- a/dan/ocr/predict/prediction.py
+++ b/dan/ocr/predict/prediction.py
@@ -5,7 +5,7 @@ import pickle
 import re
 from itertools import pairwise
 from pathlib import Path
-from typing import Iterable, List, Optional, Tuple
+from typing import Dict, Iterable, List, Optional, Tuple
 
 import numpy as np
 import torch
@@ -22,10 +22,10 @@ from dan.ocr.predict.attention import (
 )
 from dan.ocr.transforms import get_preprocessing_transforms
 from dan.utils import (
+    EntityType,
     ind_to_token,
     list_to_batches,
     pad_images,
-    parse_tokens,
     read_image,
 )
 
@@ -287,7 +287,8 @@ def process_batch(
     predict_objects: bool,
     threshold_method: str,
     threshold_value: int,
-    tokens: Path,
+    tokens: Dict[str, EntityType],
+    start_token: str,
 ) -> None:
     input_images, visu_images, input_sizes = [], [], []
     logger.info("Loading images...")
@@ -320,6 +321,7 @@ def process_batch(
         line_separators=line_separators,
         threshold_method=threshold_method,
         threshold_value=threshold_value,
+        start_token=start_token,
     )
 
     logger.info("Prediction parsing...")
@@ -336,7 +338,6 @@ def process_batch(
             result["confidences"] = {}
            char_confidences = prediction["confidences"][0]
             text = result["text"]
-            tokens = parse_tokens(tokens)
             start_tokens, end_tokens = zip(*list(tokens.values()))
             end_tokens = list(filter(bool, end_tokens))
 
@@ -422,7 +423,8 @@ def run(
     image_extension: str,
     gpu_device: int,
     batch_size: int,
-    tokens: Path,
+    tokens: Dict[str, EntityType],
+    start_token: str,
 ) -> None:
     """
     Predict a single image save the output
@@ -443,6 +445,8 @@
     :param threshold_value: Thresholding value to use for the "simple" thresholding method.
     :param gpu_device: Use a specific GPU if available.
     :param batch_size: Size of the batches for prediction.
+    :param tokens: NER tokens used.
+    :param start_token: Use a specific starting token at the beginning of the prediction. Useful when making predictions on different single pages.
     """
     # Create output directory if necessary
     if not output.exists():
@@ -472,4 +476,5 @@
         threshold_method,
         threshold_value,
         tokens,
+        start_token,
     )
diff --git a/dan/utils.py b/dan/utils.py
index 92954f8a..5cb62ff5 100644
--- a/dan/utils.py
+++ b/dan/utils.py
@@ -1,7 +1,7 @@
 # -*- coding: utf-8 -*-
 from itertools import islice
 from pathlib import Path
-from typing import NamedTuple
+from typing import Dict, NamedTuple
 
 import torch
 import torchvision.io as torchvision
@@ -106,7 +106,7 @@ def list_to_batches(iterable, n):
         yield batch
 
 
-def parse_tokens(filename: Path) -> dict:
+def parse_tokens(filename: Path) -> Dict[str, EntityType]:
     return {
         name: EntityType(**tokens)
         for name, tokens in yaml.safe_load(filename.read_text()).items()
diff --git a/docs/usage/predict.md b/docs/usage/predict.md
index 294e67aa..d6d46350 100644
--- a/docs/usage/predict.md
+++ b/docs/usage/predict.md
@@ -4,26 +4,27 @@ Use the `teklia-dan predict` command to apply a trained DAN model on an image.
 
 ## Description of parameters
 
-| Parameter                   | Description                                                                                       | Type    | Default       |
-| --------------------------- | ------------------------------------------------------------------------------------------------- | ------- | ------------- |
-| `--image`                   | Path to the image to predict. Must not be provided with `--image-dir`.                             | `Path`  |               |
-| `--image-dir`               | Path to the folder where the images to predict are stored. Must not be provided with `--image`.    | `Path`  |               |
-| `--image-extension`         | The extension of the images in the folder. Ignored if `--image-dir` is not provided.               | `str`   | .jpg          |
-| `--model`                   | Path to the model to use for prediction                                                            | `Path`  |               |
-| `--parameters`              | Path to the YAML parameters file.                                                                  | `Path`  |               |
-| `--charset`                 | Path to the charset file.                                                                          | `Path`  |               |
-| `--output`                  | Path to the output folder. Results will be saved in this directory.                                | `Path`  |               |
-| `--confidence-score`        | Whether to return confidence scores.                                                               | `bool`  | `False`       |
-| `--confidence-score-levels` | Level to return confidence scores. Should be any combination of `["line", "word", "char"]`.        | `str`   |               |
-| `--attention-map`           | Whether to plot attention maps.                                                                    | `bool`  | `False`       |
-| `--attention-map-scale`     | Image scaling factor before creating the GIF.                                                      | `float` | `0.5`         |
-| `--attention-map-level`     | Level to plot the attention maps. Should be in `["line", "word", "char"]`.                         | `str`   | `"line"`      |
-| `--predict-objects`         | Whether to return polygons coordinates.                                                            | `bool`  | `False`       |
-| `--word-separators`         | List of word separators.                                                                           | `list`  | `[" ", "\n"]` |
-| `--line-separators`         | List of line separators.                                                                           | `list`  | `["\n"]`      |
-| `--threshold-method`        | Method to use for attention mask thresholding. Should be in `["otsu", "simple"]`.                  | `str`   | `"otsu"`      |
-| `--threshold-value `        | Threshold to use for the "simple" thresholding method.                                             | `int`   | `0`           |
-| `--batch-size `             | Size of the batches for prediction.                                                                | `int`   | `1`           |
+| Parameter                   | Description                                                                                                                   | Type    | Default       |
+| --------------------------- | ----------------------------------------------------------------------------------------------------------------------------- | ------- | ------------- |
+| `--image`                   | Path to the image to predict. Must not be provided with `--image-dir`.                                                         | `Path`  |               |
+| `--image-dir`               | Path to the folder where the images to predict are stored. Must not be provided with `--image`.                                | `Path`  |               |
+| `--image-extension`         | The extension of the images in the folder. Ignored if `--image-dir` is not provided.                                           | `str`   | `.jpg`        |
+| `--model`                   | Path to the model to use for prediction.                                                                                       | `Path`  |               |
+| `--parameters`              | Path to the YAML parameters file.                                                                                              | `Path`  |               |
+| `--charset`                 | Path to the charset file.                                                                                                      | `Path`  |               |
+| `--output`                  | Path to the output folder. Results will be saved in this directory.                                                            | `Path`  |               |
+| `--confidence-score`        | Whether to return confidence scores.                                                                                           | `bool`  | `False`       |
+| `--confidence-score-levels` | Levels to return confidence scores. Should be any combination of `["line", "word", "char"]`.                                   | `str`   |               |
+| `--attention-map`           | Whether to plot attention maps.                                                                                                | `bool`  | `False`       |
+| `--attention-map-scale`     | Image scaling factor before creating the GIF.                                                                                  | `float` | `0.5`         |
+| `--attention-map-level`     | Level to plot the attention maps. Should be in `["line", "word", "char"]`.                                                     | `str`   | `"line"`      |
+| `--predict-objects`         | Whether to return polygon coordinates.                                                                                         | `bool`  | `False`       |
+| `--word-separators`         | List of word separators.                                                                                                       | `list`  | `[" ", "\n"]` |
+| `--line-separators`         | List of line separators.                                                                                                       | `list`  | `["\n"]`      |
+| `--threshold-method`        | Method to use for attention mask thresholding. Should be in `["otsu", "simple"]`.                                              | `str`   | `"otsu"`      |
+| `--threshold-value `        | Threshold to use for the "simple" thresholding method.                                                                         | `int`   | `0`           |
+| `--batch-size `             | Size of the batches for prediction.                                                                                            | `int`   | `1`           |
+| `--start-token `            | Use a specific starting token at the beginning of the prediction. Useful when making predictions on different single pages.    | `str`   | `None`        |
 
 ## Examples
 
diff --git a/tests/test_prediction.py b/tests/test_prediction.py
index b9e6e877..3acbb433 100644
--- a/tests/test_prediction.py
+++ b/tests/test_prediction.py
@@ -7,6 +7,7 @@ import pytest
 
 from dan.ocr.predict.prediction import DAN
 from dan.ocr.predict.prediction import run as run_prediction
+from dan.utils import parse_tokens
 
 
 @pytest.mark.parametrize(
@@ -316,7 +317,8 @@ def test_run_prediction(
         image_extension=None,
         gpu_device=None,
         batch_size=1,
-        tokens=prediction_data_path / "tokens.yml",
+        tokens=parse_tokens(prediction_data_path / "tokens.yml"),
+        start_token=None,
     )
 
     prediction = json.loads((tmp_path / image_name).with_suffix(".json").read_text())
@@ -514,7 +516,8 @@ def test_run_prediction_batch(
         image_extension=".png",
         gpu_device=None,
         batch_size=batch_size,
-        tokens=prediction_data_path / "tokens.yml",
+        tokens=parse_tokens(prediction_data_path / "tokens.yml"),
+        start_token=None,
     )
 
     for image_name, expected_prediction in zip(image_names, expected_predictions):
-- 
GitLab
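
With this change, `--tokens` is parsed while argparse runs, so `run()` and `process_batch()` receive a `Dict[str, EntityType]` instead of a path and no longer call `parse_tokens` themselves. The sketch below illustrates the expected YAML layout and the parsed result; the entity names, the token characters, the `tokens.yml` filename and the assumption that `EntityType` exposes `start` and `end` fields are illustrative only, not values taken from this patch.

```python
# Minimal sketch, not part of the patch: write a hypothetical tokens file and parse it
# the way the new `--tokens` argument does. Entity names, token characters and the
# `start`/`end` field names of EntityType are assumptions for the example.
from pathlib import Path

from dan.utils import parse_tokens

tokens_path = Path("tokens.yml")  # hypothetical tokens file
tokens_path.write_text(
    "surname:\n"
    "  start: Ⓢ\n"
    "  end: Ⓣ\n"
    "birthdate:\n"
    "  start: Ⓑ\n"
    "  end: ''\n"  # an empty end token is filtered out downstream
)

tokens = parse_tokens(tokens_path)  # Dict[str, EntityType], as now passed to run()
start_tokens, end_tokens = zip(*tokens.values())
end_tokens = [token for token in end_tokens if token]  # same filtering as process_batch
print(start_tokens, end_tokens)
```

Because the mapping is already parsed when it reaches the prediction code, `--tokens` can also be left out entirely, and `--start-token` is forwarded unchanged down to the prediction call.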