Compare revisions: atr/dan
Commits on Source (20)
Showing with 347 additions and 71 deletions
[submodule "nerval"]
path = nerval
url = ../../ner/nerval.git
......@@ -7,7 +7,12 @@ RUN apt-get -y update && \
WORKDIR /src
# Install DAN as a package
# Copy submodule data
COPY nerval nerval
# Copy DAN data
COPY dan dan
COPY requirements.txt *-requirements.txt setup.py VERSION README.md ./
# Install DAN as a package
RUN pip install . --no-cache-dir
MIT License
Copyright (c) 2023 Teklia
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
include LICENSE
include requirements.txt
include doc-requirements.txt
include mlflow-requirements.txt
......
......@@ -4,11 +4,14 @@
For more details about this package, make sure to see the documentation available at <https://atr.pages.teklia.com/dan/>.
This is an open-source project, licensed using [the MIT license](https://opensource.org/license/mit/).
## Development
For development and testing purposes, it may be useful to install the project as an editable package with pip.
- Use a virtualenv (e.g. with virtualenvwrapper `mkvirtualenv -a . dan`)
- Initialize the [`Nerval`](https://gitlab.teklia.com/ner/nerval) submodule (e.g. `git submodule update --init --recursive`)
- Install `dan` as a package (e.g. `pip install -e .`)
### Linter
......
{
"dataset": {
"datasets": {
"training": "tests/data/training/training_dataset"
"training": "tests/data/prediction"
},
"train": {
"name": "training-train",
......@@ -19,8 +19,8 @@
["training", "test"]
]
},
"max_char_prediction": 30,
"tokens": null
"max_char_prediction": 200,
"tokens": "tests/data/prediction/tokens.yml"
},
"model": {
"transfered_charset": true,
......@@ -45,7 +45,7 @@
},
"training": {
"data": {
"batch_size": 2,
"batch_size": 1,
"load_in_memory": true,
"worker_per_gpu": 4,
"preprocessings": [
......
......@@ -93,16 +93,18 @@ def add_extract_parser(subcommands) -> None:
)
parser.add_argument(
"--transcription-worker-version",
"--transcription-worker-versions",
type=parse_worker_version,
nargs="+",
help=f"Filter transcriptions by worker_version. Use {MANUAL_SOURCE} for manual filtering.",
required=False,
default=[],
)
parser.add_argument(
"--entity-worker-version",
"--entity-worker-versions",
type=parse_worker_version,
nargs="+",
help=f"Filter transcriptions entities by worker_version. Use {MANUAL_SOURCE} for manual filtering.",
required=False,
default=[],
)
parser.add_argument(
......
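Both flags now accept one or more values thanks to `nargs="+"`. A minimal, standalone sketch of the parsing behaviour, using a stand-in for `parse_worker_version` (assumed here to map the literal `manual` to `False`):

```python
from argparse import ArgumentParser

MANUAL_SOURCE = "manual"  # assumed value of the constant used in the help text

def parse_worker_version(value: str) -> str | bool:
    # Stand-in for the project's converter: "manual" selects transcriptions
    # with no worker_version (manual annotations); anything else is kept as an ID.
    return False if value == MANUAL_SOURCE else value

parser = ArgumentParser()
parser.add_argument(
    "--transcription-worker-versions",
    type=parse_worker_version,
    nargs="+",
    default=[],
)

args = parser.parse_args(["--transcription-worker-versions", "manual", "some-version-id"])
print(args.transcription_worker_versions)  # [False, 'some-version-id']
```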
......@@ -56,8 +56,8 @@ class ArkindexExtractor:
entity_separators: List[str] = ["\n", " "],
unknown_token: str = "",
tokens: Path | None = None,
transcription_worker_version: str | bool | None = None,
entity_worker_version: str | bool | None = None,
transcription_worker_versions: List[str | bool] = [],
entity_worker_versions: List[str | bool] = [],
keep_spaces: bool = False,
allow_empty: bool = False,
subword_vocab_size: int = 1000,
......@@ -68,8 +68,8 @@ class ArkindexExtractor:
self.entity_separators = entity_separators
self.unknown_token = unknown_token
self.tokens = parse_tokens(tokens) if tokens else {}
self.transcription_worker_version = transcription_worker_version
self.entity_worker_version = entity_worker_version
self.transcription_worker_versions = transcription_worker_versions
self.entity_worker_versions = entity_worker_versions
self.allow_empty = allow_empty
self.mapping = LMTokenMapping()
self.keep_spaces = keep_spaces
......@@ -98,7 +98,7 @@ class ArkindexExtractor:
If the entities are needed, they are added to the transcription using tokens.
"""
transcriptions = get_transcriptions(
element.id, self.transcription_worker_version
element.id, self.transcription_worker_versions
)
if len(transcriptions) == 0:
if self.allow_empty:
......@@ -112,7 +112,7 @@ class ArkindexExtractor:
entities = get_transcription_entities(
transcription.id,
self.entity_worker_version,
self.entity_worker_versions,
supported_types=list(self.tokens),
)
......@@ -282,6 +282,10 @@ class ArkindexExtractor:
)
continue
# Extract the train set first to correctly build the `self.charset` variable
splits.remove(TRAIN_NAME)
splits.insert(0, TRAIN_NAME)
# Iterate over the subsets to find the page images and labels.
for split in splits:
with tqdm(
......@@ -315,8 +319,8 @@ def run(
entity_separators: List[str],
unknown_token: str,
tokens: Path,
transcription_worker_version: str | bool | None,
entity_worker_version: str | bool | None,
transcription_worker_versions: List[str | bool],
entity_worker_versions: List[str | bool],
keep_spaces: bool,
allow_empty: bool,
subword_vocab_size: int,
......@@ -334,8 +338,8 @@ def run(
entity_separators=entity_separators,
unknown_token=unknown_token,
tokens=tokens,
transcription_worker_version=transcription_worker_version,
entity_worker_version=entity_worker_version,
transcription_worker_versions=transcription_worker_versions,
entity_worker_versions=entity_worker_versions,
keep_spaces=keep_spaces,
allow_empty=allow_empty,
subword_vocab_size=subword_vocab_size,
......
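The two inserted lines simply move the train split to the front of the iteration order, since `self.charset` must be built from the train set before the other splits are processed. A quick illustration, assuming `TRAIN_NAME` is `"train"`:

```python
TRAIN_NAME = "train"  # assumed value of the constant

splits = ["val", "test", "train"]
splits.remove(TRAIN_NAME)
splits.insert(0, TRAIN_NAME)
print(splits)  # ['train', 'val', 'test']
```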
......@@ -51,18 +51,22 @@ def get_elements(
return query
def build_worker_version_filter(ArkindexModel, worker_version):
def build_worker_version_filter(ArkindexModel, worker_versions: List[str | bool]):
"""
`False` worker version means `manual` worker_version -> null field.
"""
if worker_version:
return ArkindexModel.worker_version == worker_version
else:
return ArkindexModel.worker_version.is_null()
condition = None
for worker_version in worker_versions:
condition |= (
ArkindexModel.worker_version == worker_version
if worker_version
else ArkindexModel.worker_version.is_null()
)
return condition
def get_transcriptions(
element_id: str, transcription_worker_version: str | bool
element_id: str, transcription_worker_versions: List[str | bool]
) -> List[Transcription]:
"""
Retrieve transcriptions from an SQLite export of an Arkindex corpus
......@@ -71,10 +75,10 @@ def get_transcriptions(
Transcription.id, Transcription.text, Transcription.worker_version
).where((Transcription.element == element_id))
if transcription_worker_version is not None:
if transcription_worker_versions:
query = query.where(
build_worker_version_filter(
Transcription, worker_version=transcription_worker_version
Transcription, worker_versions=transcription_worker_versions
)
)
return query
......@@ -82,7 +86,7 @@ def get_transcriptions(
def get_transcription_entities(
transcription_id: str,
entity_worker_version: str | bool | None,
entity_worker_versions: List[str | bool],
supported_types: List[str],
) -> List[TranscriptionEntity]:
"""
......@@ -104,10 +108,10 @@ def get_transcription_entities(
)
)
if entity_worker_version is not None:
if entity_worker_versions:
query = query.where(
build_worker_version_filter(
TranscriptionEntity, worker_version=entity_worker_version
TranscriptionEntity, worker_versions=entity_worker_versions
)
)
......
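The rewritten `build_worker_version_filter` folds every requested version into a single OR'd peewee condition, with a falsy value (the parsed form of `manual`) translated to an `IS NULL` test. Below is a self-contained sketch of the same technique against an in-memory SQLite database; the model is illustrative rather than the project's actual schema, and the fold is seeded explicitly instead of starting from `None`:

```python
from peewee import CharField, Model, SqliteDatabase

db = SqliteDatabase(":memory:")

class Transcription(Model):
    text = CharField()
    worker_version = CharField(null=True)  # NULL marks manual transcriptions

    class Meta:
        database = db

def build_worker_version_filter(model, worker_versions):
    # Fold each requested version into one OR'd condition.
    condition = None
    for worker_version in worker_versions:
        clause = (
            model.worker_version == worker_version
            if worker_version
            else model.worker_version.is_null()
        )
        condition = clause if condition is None else (condition | clause)
    return condition

db.create_tables([Transcription])
Transcription.create(text="manual one", worker_version=None)
Transcription.create(text="auto one", worker_version="worker-a")
Transcription.create(text="auto two", worker_version="worker-b")

# Keep manual transcriptions and those produced by worker-a.
rows = Transcription.select().where(
    build_worker_version_filter(Transcription, [False, "worker-a"])
)
print(sorted(r.text for r in rows))  # ['auto one', 'manual one']
```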
......@@ -5,17 +5,47 @@ Evaluate a trained DAN model.
import logging
import random
from argparse import ArgumentTypeError
from itertools import chain
from operator import attrgetter
from pathlib import Path
from typing import Dict, List
import numpy as np
import torch
import torch.multiprocessing as mp
from edlib import align, getNiceAlignment
from prettytable import MARKDOWN, PrettyTable
from dan.bio import convert
from dan.ocr.manager.metrics import Inference
from dan.ocr.manager.training import Manager
from dan.ocr.utils import add_metrics_table_row, create_metrics_table, update_config
from dan.utils import read_json
from dan.utils import parse_tokens, read_json
from nerval.evaluate import evaluate
from nerval.parse import parse_bio
from nerval.utils import print_results
logger = logging.getLogger(__name__)
NERVAL_THRESHOLD = 0.30
NB_WORST_PREDICTIONS = 5
def parse_threshold(value: str) -> float:
"""
Check that the string passed as parameter is a valid floating point number between 0 and 1
"""
try:
value = float(value)
except ValueError:
raise ArgumentTypeError("Must be a floating point number.")
if value < 0 or value > 1:
raise ArgumentTypeError("Must be between 0 and 1.")
return value
def add_evaluate_parser(subcommands) -> None:
parser = subcommands.add_parser(
......@@ -31,10 +61,87 @@ def add_evaluate_parser(subcommands) -> None:
help="Configuration file.",
)
parser.add_argument(
"--nerval-threshold",
help="Distance threshold for the match between gold and predicted entity during Nerval evaluation.",
default=NERVAL_THRESHOLD,
type=parse_threshold,
)
parser.set_defaults(func=run)
def eval(rank, config, mlflow_logging):
def print_worst_predictions(all_inferences: Dict[str, List[Inference]]):
table = PrettyTable(
field_names=[
"Image name",
"WER",
"Alignment between ground truth - prediction",
]
)
table.set_style(MARKDOWN)
worst_inferences = sorted(
chain.from_iterable(all_inferences.values()),
key=attrgetter("wer"),
reverse=True,
)[:NB_WORST_PREDICTIONS]
for inference in worst_inferences:
alignment = getNiceAlignment(
align(
inference.ground_truth,
inference.prediction,
task="path",
),
inference.ground_truth,
inference.prediction,
)
alignment_str = f'{alignment["query_aligned"]}\n{alignment["matched_aligned"]}\n{alignment["target_aligned"]}'
table.add_row([inference.image, round(inference.wer * 100, 2), alignment_str])
print(f"\n#### {NB_WORST_PREDICTIONS} worst prediction(s)\n")
print(table)
def eval_nerval(
all_inferences: Dict[str, List[Inference]],
tokens: Path,
threshold: float,
):
print("\n#### Nerval evaluation")
def inferences_to_parsed_bio(attr: str):
bio_values = []
for inference in inferences:
value = getattr(inference, attr)
bio_value = convert(value, ner_tokens=tokens)
bio_values.extend(bio_value.split("\n"))
# Parse this BIO format
return parse_bio(bio_values)
# Evaluate with Nerval
tokens = parse_tokens(tokens)
for split_name, inferences in all_inferences.items():
ground_truths = inferences_to_parsed_bio("ground_truth")
predictions = inferences_to_parsed_bio("prediction")
if not (ground_truths and predictions):
continue
scores = {
key: {
k: round(value * 100, 2) if k in ["P", "R", "F1"] else value
for k, value in values.items()
}
for key, values in evaluate(ground_truths, predictions, threshold).items()
}
print(f"\n##### {split_name}\n")
print_results(scores)
def eval(rank, config: dict, nerval_threshold: float, mlflow_logging: bool):
torch.manual_seed(0)
torch.cuda.manual_seed(0)
np.random.seed(0)
......@@ -62,10 +169,12 @@ def eval(rank, config, mlflow_logging):
metric_names.append("ner")
metrics_table = create_metrics_table(metric_names)
all_inferences = {}
for dataset_name in config["dataset"]["datasets"]:
for set_name in ["train", "val", "test"]:
logger.info(f"Evaluating on set `{set_name}`")
metrics = model.evaluate(
metrics, inferences = model.evaluate(
"{}-{}".format(dataset_name, set_name),
[
(dataset_name, set_name),
......@@ -75,11 +184,22 @@ def eval(rank, config, mlflow_logging):
)
add_metrics_table_row(metrics_table, set_name, metrics)
all_inferences[set_name] = inferences
print("\n#### DAN evaluation\n")
print(metrics_table)
if "ner" in metric_names:
eval_nerval(
all_inferences,
tokens=config["dataset"]["tokens"],
threshold=nerval_threshold,
)
print_worst_predictions(all_inferences)
def run(config: dict):
def run(config: dict, nerval_threshold: float):
update_config(config)
mlflow_logging = bool(config.get("mlflow"))
......@@ -94,8 +214,8 @@ def run(config: dict):
):
mp.spawn(
eval,
args=(config, mlflow_logging),
args=(config, nerval_threshold, mlflow_logging),
nprocs=config["training"]["device"]["nb_gpu"],
)
else:
eval(0, config, mlflow_logging)
eval(0, config, nerval_threshold, mlflow_logging)
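`print_worst_predictions` renders each ground truth/prediction pair as a three-line alignment through `edlib`, exactly as imported above. A standalone sketch of that call:

```python
from edlib import align, getNiceAlignment

ground_truth = "the quick brown fox"
prediction = "the quiet brown fax"

alignment = getNiceAlignment(
    align(ground_truth, prediction, task="path"),  # "path" is required to rebuild the alignment
    ground_truth,
    prediction,
)
print(alignment["query_aligned"])    # ground truth, padded with "-" where needed
print(alignment["matched_aligned"])  # "|" for matches, "." for mismatches
print(alignment["target_aligned"])   # prediction
```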
......@@ -3,7 +3,7 @@ import re
from collections import defaultdict
from operator import attrgetter
from pathlib import Path
from typing import Dict, List
from typing import Dict, List, NamedTuple
import editdistance
import numpy as np
......@@ -23,6 +23,18 @@ REGEX_ONLY_ONE_SPACE = re.compile(r"\s+")
METRICS_KEYWORD = {"cer": "chars", "wer": "words", "ner": "tokens"}
class Inference(NamedTuple):
"""
Store a prediction with its ground truth to avoid
inferring again when we need to compute new metrics
"""
image: str
ground_truth: str
prediction: str
wer: float
class MetricManager:
def __init__(self, metric_names: List[str], dataset_name: str, tokens: Path | None):
self.dataset_name: str = dataset_name
......
......@@ -4,9 +4,10 @@ import os
import random
from copy import deepcopy
from enum import Enum
from itertools import repeat
from pathlib import Path
from time import time
from typing import Dict
from typing import Dict, List, Tuple
import numpy as np
import torch
......@@ -20,7 +21,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from dan.ocr.manager.metrics import MetricManager
from dan.ocr.manager.metrics import Inference, MetricManager
from dan.ocr.manager.ocr import OCRDatasetManager
from dan.ocr.mlflow import MLFLOW_AVAILABLE, logging_metrics, logging_tags_metrics
from dan.ocr.schedulers import DropoutScheduler
......@@ -750,7 +751,7 @@ class GenericTrainingManager:
def evaluate(
self, custom_name, sets_list, metric_names, mlflow_logging=False
) -> Dict[str, int | float]:
) -> Tuple[Dict[str, int | float], List[Inference]]:
"""
Main loop for evaluation
"""
......@@ -768,6 +769,11 @@ class GenericTrainingManager:
tokens=self.tokens,
)
# Keep inferences in memory to:
# - evaluate with Nerval
# - display worst predictions
inferences = []
with tqdm(total=len(loader.dataset)) as pbar:
pbar.set_description("Evaluation")
with torch.no_grad():
......@@ -792,6 +798,16 @@ class GenericTrainingManager:
pbar.set_postfix(values=str(display_values))
pbar.update(len(batch_data["names"]) * self.nb_workers)
inferences.extend(
map(
Inference,
batch_data["names"],
batch_values["str_y"],
batch_values["str_x"],
repeat(display_values["wer"]),
)
)
# log metrics in MLflow
logging_name = custom_name.split("-")[1]
logging_tags_metrics(
......@@ -810,7 +826,7 @@ class GenericTrainingManager:
# Log mlflow artifacts
mlflow.log_artifact(path, "predictions")
return metrics
return metrics, inferences
def output_pred(self, name):
path = self.paths["results"] / "predict_{}_{}.yaml".format(
......
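The `map(Inference, …, repeat(...))` call builds one `Inference` per image in the batch while reusing the batch-level WER for every row, which is why the documentation warns that worst-prediction WERs are only per-image when `batch_size` is `1`. A small sketch of the idiom:

```python
from itertools import repeat
from typing import NamedTuple

class Inference(NamedTuple):
    image: str
    ground_truth: str
    prediction: str
    wer: float

names = ["a.png", "b.png"]
ground_truths = ["foo", "bar"]
predictions = ["f0o", "bar"]
batch_wer = 0.5  # one value shared by the whole batch

# map() stops at the shortest iterable, so repeat() safely pads the WER column.
inferences = list(map(Inference, names, ground_truths, predictions, repeat(batch_wer)))
print(inferences[0])
# Inference(image='a.png', ground_truth='foo', prediction='f0o', wer=0.5)
```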
......@@ -93,7 +93,7 @@ def compute_prob_by_ner(
return zip(
*[
(
f"{characters[current: next_token]}".replace("\n", " "),
characters[current:next_token],
np.mean(probabilities[current:next_token]),
)
for current, next_token in indices
......@@ -154,10 +154,7 @@ def split_text(
return [], []
indices = build_ner_indices(text, tokens)
text_split = [
f"{text[current: next_token]}".replace("\n", " ")
for current, next_token in indices
]
text_split = [text[current:next_token] for current, next_token in indices]
case _:
logger.error(f"Level should be either {list(map(str, Level))}")
return [], []
......@@ -201,6 +198,9 @@ def split_text_and_confidences(
return [], [], []
indices = build_ner_indices(text, tokens)
if not indices:
return [], [], []
texts, confidences = compute_prob_by_ner(text, confidences, indices)
case _:
logger.error(f"Level should be either {list(map(str, Level))}")
......@@ -256,6 +256,8 @@ def get_predicted_polygons_with_confidence(
size=(width, height),
)
start_index += len(text_piece) + offset
if not polygon:
continue
polygon["text"] = text_piece
polygon["text_confidence"] = confidence
polygons.append(polygon)
......@@ -365,7 +367,7 @@ def get_polygon(
weights: np.ndarray,
size: Tuple[int, int] | None = None,
max_object_height: int = 50,
) -> Tuple[dict, np.ndarray]:
) -> Tuple[dict, np.ndarray | None]:
"""
Gets polygon associated with element of current text_piece, indexed by offset
:param text: Text piece selected with offset after splitting DAN prediction
......@@ -385,6 +387,8 @@ def get_polygon(
if max_object_height
else get_best_contour(coverage_vector, bin_mask)
)
if not coord or confidence is None:
return {}, None
# Format for JSON
polygon = {
......@@ -403,7 +407,7 @@ def get_best_contour(coverage_vector, bin_mask):
contours, _ = cv2.findContours(bin_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
if not contours:
return {}, None
return [], None
# Select best contour
metrics = [compute_contour_metrics(coverage_vector, cnt) for cnt in contours]
......@@ -420,6 +424,10 @@ def get_grid_search_contour(coverage_vector, bin_mask, height=50):
"""
# Limit search area based on attention values
roi = np.argwhere(bin_mask == 255)
if not np.any(roi):
return [], None
y_min, y_max = roi[:, 0].min(), roi[:, 0].max()
# Limit bounding box shape
......
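The added `np.any` guard protects `get_grid_search_contour` from calling `.min()` on an empty index array when the binarized attention mask has no active pixels. A minimal illustration:

```python
import numpy as np

bin_mask = np.zeros((4, 4), dtype=np.uint8)  # no pixel reaches the threshold

roi = np.argwhere(bin_mask == 255)
if not np.any(roi):
    print("empty ROI: no contour to return")  # `roi[:, 0].min()` would raise ValueError
else:
    y_min, y_max = roi[:, 0].min(), roi[:, 0].max()
    print(y_min, y_max)
```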
......@@ -93,6 +93,7 @@ def add_metrics_table_row(
continue
metric_name = REVERSE_HEADER[column]
row.append(metrics.get(metric_name, ""))
metric_value = metrics.get(metric_name)
row.append(round(metric_value * 100, 2) if metric_value is not None else "")
table.add_row(row)
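The replacement lines convert each metric to a rounded percentage before adding it to the table, instead of inserting the raw 0-1 value; a missing metric still renders as an empty cell. In short:

```python
metrics = {"cer": 0.12345}

for metric_name in ("cer", "ner"):
    metric_value = metrics.get(metric_name)
    print(round(metric_value * 100, 2) if metric_value is not None else "")
# 12.35
# (empty cell for the missing "ner" metric)
```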
......@@ -26,6 +26,12 @@ To install DAN manually, you need to first clone via:
git clone git@gitlab.teklia.com:atr/dan.git
```
Then you can initialize the [`Nerval`](https://gitlab.teklia.com/ner/nerval) submodule:
```shell
git submodule update --init --recursive
```
Then you can install DAN via pip:
```shell
......
......@@ -4,8 +4,6 @@
Use the `teklia-dan dataset analyze` command to analyze a dataset. This will display statistics in [Markdown](https://www.markdownguide.org/) format.
The available arguments are:
| Parameter | Description | Type | Default |
| --------------- | -------------------------------- | -------------- | ------- |
| `--labels` | Path to the `labels.json` file. | `pathlib.Path` | |
......
......@@ -8,20 +8,20 @@ Use the `teklia-dan dataset extract` command to extract a dataset from an Arkind
- Store the set of characters encountered in the dataset (in the `charset.pkl` file),
- Generate the resources needed to build a n-gram language model at character, subword or word-level with [kenlm](https://github.com/kpu/kenlm) (in the `language_model/` folder).
| Parameter | Description | Type | Default |
| -------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------------- | ------- |
| `database` | Path to an Arkindex export database in SQLite format. | `pathlib.Path` | |
| `--dataset-id ` | ID of the dataset to extract from Arkindex. | `uuid` | |
| `--element-type` | Type of the elements to extract. You may specify multiple types. | `str` | |
| `--output` | Folder where the data will be generated. | `pathlib.Path` | |
| `--entity-separators` | Removes all text that does not appear in an entity or in the list of given ordered characters. If several separators follow each other, keep only the first to appear in the list. Do not give any arguments to keep the whole text (see [dedicated section](#examples)). | `str` | |
| `--unknown-token` | Token to use to replace character in the validation/test sets that is not included in the training set. | `str` | `⁇` |
| `--tokens` | Mapping between starting tokens and end tokens to extract text with their entities. | `pathlib.Path` | |
| `--transcription-worker-version` | Filter transcriptions by worker_version. Use `manual` for manual filtering. | `str` or `uuid` | |
| `--entity-worker-version` | Filter transcriptions entities by worker_version. Use `manual` for manual filtering | `str` or `uuid` | |
| `--keep-spaces` | Transcriptions are trimmed by default. Use this flag to disable this behaviour. | `bool` | `False` |
| `--allow-empty` | Elements with no transcriptions are skipped by default. This flag disables this behaviour. | `bool` | `False` |
| `--subword-vocab-size` | Size of the vocabulary used to train the sentencepiece subword tokenizer used to train the optional language model. | `int` | `1000` |
| Parameter | Description | Type | Default |
| --------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------------- | ------- |
| `database` | Path to an Arkindex export database in SQLite format. | `pathlib.Path` | |
| `--dataset-id ` | ID of the dataset to extract from Arkindex. | `uuid` | |
| `--element-type` | Type of the elements to extract. You may specify multiple types. | `str` | |
| `--output` | Folder where the data will be generated. | `pathlib.Path` | |
| `--entity-separators` | Removes all text that does not appear in an entity or in the list of given ordered characters. If several separators follow each other, keep only the first to appear in the list. Do not give any arguments to keep the whole text (see [dedicated section](#examples)). | `str` | |
| `--unknown-token`                 | Token used to replace any character in the validation/test sets that is not included in the training set.                                                                                                                                                                   | `str`           | `⁇`     |
| `--tokens` | Mapping between starting tokens and end tokens to extract text with their entities. | `pathlib.Path` | |
| `--transcription-worker-versions` | Filter transcriptions by worker_version. Use `manual` for manual filtering. | `str` or `uuid` | |
| `--entity-worker-versions`        | Filter transcription entities by worker_version. Use `manual` for manual filtering.                                                                                                                                                                                         | `str` or `uuid` |         |
| `--keep-spaces` | Transcriptions are trimmed by default. Use this flag to disable this behaviour. | `bool` | `False` |
| `--allow-empty` | Elements with no transcriptions are skipped by default. This flag disables this behaviour. | `bool` | `False` |
| `--subword-vocab-size` | Size of the vocabulary used to train the sentencepiece subword tokenizer used to train the optional language model. | `int` | `1000` |
The `--tokens` argument expects a YAML-formatted file with a specific format: a list of entries, each describing a NER entity. The label of the entity is the key to a dict mapping its starting and ending tokens respectively. This file can be generated by the `teklia-dan dataset tokens` command. More details in the [dedicated page](./tokens.md).
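For illustration only, the mapping could resemble the sketch below; the entity labels and the `start`/`end` keys are assumptions here, not the documented schema, and `teklia-dan dataset tokens` remains the reference way to generate the file:

```python
import yaml  # pip install pyyaml

# Assumed shape: each NER label maps to its starting and ending tokens.
tokens = {
    "Surname": {"start": "Ⓢ", "end": ""},
    "Firstname": {"start": "Ⓕ", "end": ""},
}
print(yaml.safe_dump(tokens, allow_unicode=True))
```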
......
# Evaluation
## Description
Use the `teklia-dan evaluate` command to evaluate a trained DAN model.
To evaluate DAN on your dataset:
1. Create a JSON configuration file. You can base the configuration file off the training one. Refer to the [dedicated page](../train/config.md) for a description of parameters.
1. Run `teklia-dan evaluate --config path/to/your/config.json`.
1. Evaluation results for every split are available in the `results` subfolder of the output folder indicated in your configuration.
1. A metrics Markdown table, providing results for each evaluated split, is also printed in the console (see table example below).
### Example output - Metrics Markdown table
This will, for each evaluated split:
1. Create a YAML file with the evaluation results in the `results` subfolder of the `training.output_folder` indicated in your configuration.
1. Print in the console a metrics Markdown table (see [HTR example below](#htr-evaluation)).
1. Print in the console a [Nerval](https://gitlab.teklia.com/ner/nerval) metrics Markdown table, if the `dataset.tokens` parameter in your configuration is defined (see [HTR and NER example below](#htr-and-ner-evaluation)).
1. Print in the console the 5 worst predictions (see [examples below](#examples)).
!!! warning
The display of the worst predictions does not support batch evaluation. If the `training.data.batch_size` parameter is not equal to `1`, then the `WER` displayed is the `WER` of the **whole batch** and not just the image.
| Parameter | Description | Type | Default |
| -------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -------------- | ------- |
| `--config` | Path to the configuration file. | `pathlib.Path` | |
| `--nerval-threshold` | Distance threshold for the match between gold and predicted entity during Nerval evaluation. `0` would impose perfect matches, `1` would allow completely different strings to be considered as a match. | `float` | `0.3` |
## Examples
### HTR evaluation
```
#### DAN evaluation
| Split | CER (HTR-NER) | CER (HTR) | WER (HTR-NER) | WER (HTR) | WER (HTR no punct) |
| :---: | :-----------: | :-------: | :-----------: | :-------: | :----------------: |
| train | x | x | x | x | x |
| val | x | x | x | x | x |
| test | x | x | x | x | x |
#### 5 worst prediction(s)
| Image name | WER | Alignment between ground truth - prediction |
| :------------: | :-: | :-----------------------------------------: |
| <image_id>.png | x | x |
| | | | |
| | | x |
```
### HTR and NER evaluation
```
#### DAN evaluation
| Split | CER (HTR-NER) | CER (HTR) | WER (HTR-NER) | WER (HTR) | WER (HTR no punct) | NER |
| :---: | :-----------: | :-------: | :-----------: | :-------: | :----------------: | :-: |
| train | x | x | x | x | x | x |
| val | x | x | x | x | x | x |
| test | x | x | x | x | x | x |
#### Nerval evaluation
##### train
| tag | predicted | matched | Precision | Recall | F1 | Support |
| :-----: | :-------: | :-----: | :-------: | :----: | :-: | :-----: |
| Surname | x | x | x | x | x | x |
| All | x | x | x | x | x | x |
##### val
| tag | predicted | matched | Precision | Recall | F1 | Support |
| :-----: | :-------: | :-----: | :-------: | :----: | :-: | :-----: |
| Surname | x | x | x | x | x | x |
| All | x | x | x | x | x | x |
##### test
| tag | predicted | matched | Precision | Recall | F1 | Support |
| :-----: | :-------: | :-----: | :-------: | :----: | :-: | :-----: |
| Surname | x | x | x | x | x | x |
| All | x | x | x | x | x | x |
#### 5 worst prediction(s)
| Image name | WER | Alignment between ground truth - prediction |
| :------------: | :-: | :-----------------------------------------: |
| <image_id>.png | x | x |
| | | | |
| | | x |
```
# Prediction
Use the `teklia-dan predict` command to apply a trained DAN model on an image.
## Description
Use the `teklia-dan predict` command to apply a trained DAN model on an image.
| Parameter | Description | Type | Default |
| --------------------------- | ------------------------------------------------------------------------------------------------------------------------------------- | -------------- | ------------- |
| `--image-dir` | Path to the folder where the images to predict are stored. Must not be provided with `--image`. | `pathlib.Path` | |
......
......@@ -144,7 +144,7 @@ extra:
link: https://teklia.com
- icon: fontawesome/brands/gitlab
name: Git repository for this project
link: https://gitlab.com/teklia/atr/dan
link: https://gitlab.teklia.com/atr/dan
- icon: fontawesome/brands/linkedin
name: Teklia @ LinkedIn
link: https://www.linkedin.com/company/teklia
......