Skip to content
Snippets Groups Projects
Commit c5fa54a3 authored by Yoann Schneider's avatar Yoann Schneider :tennis: Committed by Manon Blanco
Browse files

Parse tokens during argparse parsing and expose start_token

parent 78281e2e
No related branches found
No related tags found
1 merge request!254Parse tokens during argparse parsing and expose start_token
......@@ -6,6 +6,7 @@ Predict on an image using a trained DAN model.
import pathlib
from dan.ocr.predict.prediction import run
from dan.utils import parse_tokens
def add_predict_parser(subcommands) -> None:
......@@ -51,13 +52,13 @@ def add_predict_parser(subcommands) -> None:
help="Path to the output folder.",
required=True,
)
# Optional arguments.
parser.add_argument(
"--tokens",
type=pathlib.Path,
required=True,
type=parse_tokens,
required=False,
help="Path to a yaml file containing a mapping between starting tokens and end tokens. Needed for entities.",
)
# Optional arguments.
parser.add_argument(
"--image-extension",
type=str,
......@@ -154,5 +155,10 @@ def add_predict_parser(subcommands) -> None:
default=1,
required=False,
)
parser.add_argument(
"--start-token",
help="Use a specific starting token at the beginning of the prediction. Useful when making predictions on different single pages.",
type=str,
required=False,
)
parser.set_defaults(func=run)
......@@ -5,7 +5,7 @@ import pickle
import re
from itertools import pairwise
from pathlib import Path
from typing import Iterable, List, Optional, Tuple
from typing import Dict, Iterable, List, Optional, Tuple
import numpy as np
import torch
......@@ -22,10 +22,10 @@ from dan.ocr.predict.attention import (
)
from dan.ocr.transforms import get_preprocessing_transforms
from dan.utils import (
EntityType,
ind_to_token,
list_to_batches,
pad_images,
parse_tokens,
read_image,
)
......@@ -287,7 +287,8 @@ def process_batch(
predict_objects: bool,
threshold_method: str,
threshold_value: int,
tokens: Path,
tokens: Dict[str, EntityType],
start_token: str,
) -> None:
input_images, visu_images, input_sizes = [], [], []
logger.info("Loading images...")
......@@ -320,6 +321,7 @@ def process_batch(
line_separators=line_separators,
threshold_method=threshold_method,
threshold_value=threshold_value,
start_token=start_token,
)
logger.info("Prediction parsing...")
......@@ -336,7 +338,6 @@ def process_batch(
result["confidences"] = {}
char_confidences = prediction["confidences"][0]
text = result["text"]
tokens = parse_tokens(tokens)
start_tokens, end_tokens = zip(*list(tokens.values()))
end_tokens = list(filter(bool, end_tokens))
......@@ -422,7 +423,8 @@ def run(
image_extension: str,
gpu_device: int,
batch_size: int,
tokens: Path,
tokens: Dict[str, EntityType],
start_token: str,
) -> None:
"""
Predict a single image save the output
......@@ -443,6 +445,8 @@ def run(
:param threshold_value: Thresholding value to use for the "simple" thresholding method.
:param gpu_device: Use a specific GPU if available.
:param batch_size: Size of the batches for prediction.
:param tokens: NER tokens used.
:param start_token: Use a specific starting token at the beginning of the prediction. Useful when making predictions on different single pages.
"""
# Create output directory if necessary
if not output.exists():
......@@ -472,4 +476,5 @@ def run(
threshold_method,
threshold_value,
tokens,
start_token,
)
# -*- coding: utf-8 -*-
from itertools import islice
from pathlib import Path
from typing import NamedTuple
from typing import Dict, NamedTuple
import torch
import torchvision.io as torchvision
......@@ -106,7 +106,7 @@ def list_to_batches(iterable, n):
yield batch
def parse_tokens(filename: Path) -> dict:
def parse_tokens(filename: Path) -> Dict[str, EntityType]:
return {
name: EntityType(**tokens)
for name, tokens in yaml.safe_load(filename.read_text()).items()
......
......@@ -4,26 +4,27 @@ Use the `teklia-dan predict` command to apply a trained DAN model on an image.
## Description of parameters
| Parameter | Description | Type | Default |
| --------------------------- | ----------------------------------------------------------------------------------------------- | ------- | ------------- |
| `--image` | Path to the image to predict. Must not be provided with `--image-dir`. | `Path` | |
| `--image-dir` | Path to the folder where the images to predict are stored. Must not be provided with `--image`. | `Path` | |
| `--image-extension` | The extension of the images in the folder. Ignored if `--image-dir` is not provided. | `str` | .jpg |
| `--model` | Path to the model to use for prediction | `Path` | |
| `--parameters` | Path to the YAML parameters file. | `Path` | |
| `--charset` | Path to the charset file. | `Path` | |
| `--output` | Path to the output folder. Results will be saved in this directory. | `Path` | |
| `--confidence-score` | Whether to return confidence scores. | `bool` | `False` |
| `--confidence-score-levels` | Level to return confidence scores. Should be any combination of `["line", "word", "char"]`. | `str` | |
| `--attention-map` | Whether to plot attention maps. | `bool` | `False` |
| `--attention-map-scale` | Image scaling factor before creating the GIF. | `float` | `0.5` |
| `--attention-map-level` | Level to plot the attention maps. Should be in `["line", "word", "char"]`. | `str` | `"line"` |
| `--predict-objects` | Whether to return polygons coordinates. | `bool` | `False` |
| `--word-separators` | List of word separators. | `list` | `[" ", "\n"]` |
| `--line-separators` | List of line separators. | `list` | `["\n"]` |
| `--threshold-method` | Method to use for attention mask thresholding. Should be in `["otsu", "simple"]`. | `str` | `"otsu"` |
| `--threshold-value ` | Threshold to use for the "simple" thresholding method. | `int` | `0` |
| `--batch-size ` | Size of the batches for prediction. | `int` | `1` |
| Parameter | Description | Type | Default |
| --------------------------- | --------------------------------------------------------------------------------------------------------------------------- | ------- | ------------- |
| `--image` | Path to the image to predict. Must not be provided with `--image-dir`. | `Path` | |
| `--image-dir` | Path to the folder where the images to predict are stored. Must not be provided with `--image`. | `Path` | |
| `--image-extension` | The extension of the images in the folder. Ignored if `--image-dir` is not provided. | `str` | .jpg |
| `--model` | Path to the model to use for prediction | `Path` | |
| `--parameters` | Path to the YAML parameters file. | `Path` | |
| `--charset` | Path to the charset file. | `Path` | |
| `--output` | Path to the output folder. Results will be saved in this directory. | `Path` | |
| `--confidence-score` | Whether to return confidence scores. | `bool` | `False` |
| `--confidence-score-levels` | Level to return confidence scores. Should be any combination of `["line", "word", "char"]`. | `str` | |
| `--attention-map` | Whether to plot attention maps. | `bool` | `False` |
| `--attention-map-scale` | Image scaling factor before creating the GIF. | `float` | `0.5` |
| `--attention-map-level` | Level to plot the attention maps. Should be in `["line", "word", "char"]`. | `str` | `"line"` |
| `--predict-objects` | Whether to return polygons coordinates. | `bool` | `False` |
| `--word-separators` | List of word separators. | `list` | `[" ", "\n"]` |
| `--line-separators` | List of line separators. | `list` | `["\n"]` |
| `--threshold-method` | Method to use for attention mask thresholding. Should be in `["otsu", "simple"]`. | `str` | `"otsu"` |
| `--threshold-value ` | Threshold to use for the "simple" thresholding method. | `int` | `0` |
| `--batch-size ` | Size of the batches for prediction. | `int` | `1` |
| `--start-token ` | Use a specific starting token at the beginning of the prediction. Useful when making predictions on different single pages. | `str` | `None` |
## Examples
......
......@@ -7,6 +7,7 @@ import pytest
from dan.ocr.predict.prediction import DAN
from dan.ocr.predict.prediction import run as run_prediction
from dan.utils import parse_tokens
@pytest.mark.parametrize(
......@@ -316,7 +317,8 @@ def test_run_prediction(
image_extension=None,
gpu_device=None,
batch_size=1,
tokens=prediction_data_path / "tokens.yml",
tokens=parse_tokens(prediction_data_path / "tokens.yml"),
start_token=None,
)
prediction = json.loads((tmp_path / image_name).with_suffix(".json").read_text())
......@@ -514,7 +516,8 @@ def test_run_prediction_batch(
image_extension=".png",
gpu_device=None,
batch_size=batch_size,
tokens=prediction_data_path / "tokens.yml",
tokens=parse_tokens(prediction_data_path / "tokens.yml"),
start_token=None,
)
for image_name, expected_prediction in zip(image_names, expected_predictions):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment