diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 2ce3e47f9308d2143856d095c42cc376810787cb..0b1099825f4e1b0ee1ad9cd946b7846b79fd94c9 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -44,7 +44,7 @@ repos:
     rev: 0.7.16
     hooks:
       - id: mdformat
-        exclude: tests/data/analyze
+        exclude: tests/data/analyze|tests/data/evaluate/metrics_table.md
        # Optionally add plugins
        additional_dependencies:
          - mdformat-mkdocs[recommended]
diff --git a/dan/ocr/evaluate.py b/dan/ocr/evaluate.py
index d998ff1585c023de9540270871dc3ee64dcd1b9a..7d514a40714df6c6eb846f7e7463a570affc6c29 100644
--- a/dan/ocr/evaluate.py
+++ b/dan/ocr/evaluate.py
@@ -11,7 +11,7 @@ import torch
 import torch.multiprocessing as mp
 
 from dan.ocr.manager.training import Manager
-from dan.ocr.utils import update_config
+from dan.ocr.utils import add_metrics_table_row, create_metrics_table, update_config
 from dan.utils import read_json
 
 logger = logging.getLogger(__name__)
@@ -50,23 +50,34 @@ def eval(rank, config, mlflow_logging):
     model = Manager(config)
     model.load_model()
 
-    metrics = ["cer", "cer_no_token", "wer", "wer_no_punct", "wer_no_token", "time"]
+    metric_names = [
+        "cer",
+        "cer_no_token",
+        "wer",
+        "wer_no_punct",
+        "wer_no_token",
+        "time",
+    ]
     if config["dataset"]["tokens"] is not None:
-        metrics.append("ner")
+        metric_names.append("ner")
 
+    metrics_table = create_metrics_table(metric_names)
+
     for dataset_name in config["dataset"]["datasets"]:
-        for set_name in ["test", "val", "train"]:
+        for set_name in ["train", "val", "test"]:
             logger.info(f"Evaluating on set `{set_name}`")
-            model.evaluate(
+            metrics = model.evaluate(
                 "{}-{}".format(dataset_name, set_name),
                 [
                     (dataset_name, set_name),
                 ],
-                metrics,
-                output=True,
+                metric_names,
                 mlflow_logging=mlflow_logging,
             )
+            add_metrics_table_row(metrics_table, set_name, metrics)
+
+    print(metrics_table)
 
 
 def run(config: dict):
     update_config(config)
diff --git a/dan/ocr/manager/training.py b/dan/ocr/manager/training.py
index ff35cff52c00580a3612b0b5cee4130e207d8a2e..ba5160fa1d3ad9ca1037077292b3530861520231 100644
--- a/dan/ocr/manager/training.py
+++ b/dan/ocr/manager/training.py
@@ -749,8 +749,8 @@ class GenericTrainingManager:
         return display_values
 
     def evaluate(
-        self, custom_name, sets_list, metric_names, mlflow_logging=False, output=False
-    ):
+        self, custom_name, sets_list, metric_names, mlflow_logging=False
+    ) -> Dict[str, int | float]:
         """
         Main loop for evaluation
         """
@@ -798,19 +798,19 @@ class GenericTrainingManager:
                 display_values, logging_name, mlflow_logging, self.is_master
             )
 
-        # output metrics values if requested
-        if output:
-            if "pred" in metric_names:
-                self.output_pred(custom_name)
-            metrics = self.metric_manager[custom_name].get_display_values(output=True)
-            path = self.paths["results"] / "predict_{}_{}.yaml".format(
-                custom_name, self.latest_epoch
-            )
-            path.write_text(yaml.dump(metrics))
+        if "pred" in metric_names:
+            self.output_pred(custom_name)
+        metrics = self.metric_manager[custom_name].get_display_values(output=True)
+        path = self.paths["results"] / "predict_{}_{}.yaml".format(
+            custom_name, self.latest_epoch
+        )
+        path.write_text(yaml.dump(metrics))
+
+        if mlflow_logging:
+            # Log mlflow artifacts
+            mlflow.log_artifact(path, "predictions")
 
-            if mlflow_logging:
-                # Log mlflow artifacts
-                mlflow.log_artifact(path, "predictions")
+        return metrics
 
     def output_pred(self, name):
         path = self.paths["results"] / "predict_{}_{}.yaml".format(
diff --git a/dan/ocr/utils.py b/dan/ocr/utils.py
index 7142b0cb1c94b37dd4fab9f5c05705c9c694d646..988ff985f81c0f0f18bfd31d30fb8e348e2edc19 100644
--- a/dan/ocr/utils.py
+++ b/dan/ocr/utils.py
@@ -1,13 +1,25 @@
 # -*- coding: utf-8 -*-
 from pathlib import Path
+from typing import Dict, List, Optional
 
 import torch
+from prettytable import MARKDOWN, PrettyTable
 from torch.optim import Adam
 
 from dan.ocr.decoder import GlobalHTADecoder
 from dan.ocr.encoder import FCN_Encoder
 from dan.ocr.transforms import Preprocessing
 
+METRICS_TABLE_HEADER = {
+    "cer": "CER (HTR-NER)",
+    "cer_no_token": "CER (HTR)",
+    "wer": "WER (HTR-NER)",
+    "wer_no_token": "WER (HTR)",
+    "wer_no_punct": "WER (HTR no punct)",
+    "ner": "NER",
+}
+REVERSE_HEADER = {column: metric for metric, column in METRICS_TABLE_HEADER.items()}
+
 
 def update_config(config: dict):
     """
@@ -51,3 +63,36 @@ def update_config(config: dict):
     # set nb_gpu if not present
     if config["training"]["device"]["nb_gpu"] is None:
         config["training"]["device"]["nb_gpu"] = torch.cuda.device_count()
+
+
+def create_metrics_table(metrics: List[str]) -> PrettyTable:
+    """
+    Create a Markdown table to display metrics (CER, WER, NER, etc.)
+    for each evaluated split.
+    """
+    table = PrettyTable(
+        field_names=["Split"]
+        + [title for metric, title in METRICS_TABLE_HEADER.items() if metric in metrics]
+    )
+    table.set_style(MARKDOWN)
+
+    return table
+
+
+def add_metrics_table_row(
+    table: PrettyTable, split: str, metrics: Optional[Dict[str, int | float]]
+) -> PrettyTable:
+    """
+    Add a row to an existing metrics Markdown table for the currently evaluated split.
+    To create such a table, please refer to the
+    [create_metrics_table][dan.ocr.utils.create_metrics_table] function.
+    """
+    row = [split]
+    for column in table.field_names:
+        if column not in REVERSE_HEADER:
+            continue
+
+        metric_name = REVERSE_HEADER[column]
+        row.append(metrics.get(metric_name, "−"))
+
+    table.add_row(row)
diff --git a/docs/ref/ocr/utils.md b/docs/ref/ocr/utils.md
new file mode 100644
index 0000000000000000000000000000000000000000..adeec07327016e1240b904610c8b5aec282a720e
--- /dev/null
+++ b/docs/ref/ocr/utils.md
@@ -0,0 +1,3 @@
+# Utils
+
+::: dan.ocr.utils
diff --git a/docs/usage/evaluate/index.md b/docs/usage/evaluate/index.md
index eedfb14989b89327e83d488ed82c27ced6642099..f1cc273be210825c853219f5182625f0dc19f21a 100644
--- a/docs/usage/evaluate/index.md
+++ b/docs/usage/evaluate/index.md
@@ -7,3 +7,12 @@ To evaluate DAN on your dataset:
 1. Create a JSON configuration file. You can base the configuration file off the training one. Refer to the [dedicated page](../train/config.md) for a description of parameters.
 1. Run `teklia-dan evaluate --config path/to/your/config.json`.
 1. Evaluation results for every split are available in the `results` subfolder of the output folder indicated in your configuration.
+1. A metrics Markdown table, providing results for each evaluated split, is also printed in the console (see the example table below).
+
+### Example output - Metrics Markdown table
+
+| Split | CER (HTR-NER) | CER (HTR) | WER (HTR-NER) | WER (HTR) | WER (HTR no punct) | NER |
+| :---: | :-----------: | :-------: | :-----------: | :-------: | :----------------: | :-: |
+| train | x | x | x | x | x | x |
+| val | x | x | x | x | x | x |
+| test | x | x | x | x | x | x |
diff --git a/mkdocs.yml b/mkdocs.yml
index 611dcb92e7f373e6767ec285954d422028eb05e8..758964cef9c45bec8799828241470a9f196ca57a 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -118,6 +118,7 @@ nav:
           - MLflow: ref/ocr/mlflow.md
           - Schedulers: ref/ocr/schedulers.md
           - Transformations: ref/ocr/transforms.md
+          - Utils: ref/ocr/utils.md
       - CLI: ref/cli.md
       - Utils: ref/utils.md
 
diff --git a/tests/data/evaluate/metrics_table.md b/tests/data/evaluate/metrics_table.md
new file mode 100644
index 0000000000000000000000000000000000000000..e9cc598fe4875bb8619ed08af9feef204eb92e8b
--- /dev/null
+++ b/tests/data/evaluate/metrics_table.md
@@ -0,0 +1,5 @@
+| Split | CER (HTR-NER) | CER (HTR) | WER (HTR-NER) | WER (HTR) | WER (HTR no punct) |
+|:-----:|:-------------:|:---------:|:-------------:|:---------:|:------------------:|
+| train | 1.3023 | 1.3023 | 1.0 | 1.0 | 1.0 |
+| val | 1.2683 | 1.2683 | 1.0 | 1.0 | 1.0 |
+| test | 1.1224 | 1.1224 | 1.0 | 1.0 | 1.0 |
diff --git a/tests/test_evaluate.py b/tests/test_evaluate.py
index 1f73b048b0b57134ec261f90537d4cb507331f5c..425f85ec915adb905f32d761e4c55bb8f47815d4 100644
--- a/tests/test_evaluate.py
+++ b/tests/test_evaluate.py
@@ -1,14 +1,52 @@
 # -*- coding: utf-8 -*-
 import shutil
+from pathlib import Path
 
 import pytest
 import yaml
+from prettytable import PrettyTable
 
 from dan.ocr import evaluate
+from dan.ocr.utils import add_metrics_table_row, create_metrics_table
 from tests import FIXTURES
 
 
+def test_create_metrics_table():
+    metric_names = ["ignored", "wer", "cer", "time", "ner"]
+    metrics_table = create_metrics_table(metric_names)
+
+    assert isinstance(metrics_table, PrettyTable)
+    assert metrics_table.field_names == [
+        "Split",
+        "CER (HTR-NER)",
+        "WER (HTR-NER)",
+        "NER",
+    ]
+
+
+def test_add_metrics_table_row():
+    metric_names = ["ignored", "wer", "cer", "time", "ner"]
+    metrics_table = create_metrics_table(metric_names)
+
+    metrics = {
+        "ignored": "whatever",
+        "wer": 1.0,
+        "cer": 1.3023,
+        "time": 42,
+    }
+    add_metrics_table_row(metrics_table, "train", metrics)
+
+    assert isinstance(metrics_table, PrettyTable)
+    assert metrics_table.field_names == [
+        "Split",
+        "CER (HTR-NER)",
+        "WER (HTR-NER)",
+        "NER",
+    ]
+    assert metrics_table.rows == [["train", 1.3023, 1.0, "−"]]
+
+
 @pytest.mark.parametrize(
     "training_res, val_res, test_res",
     (
         (
@@ -55,7 +93,7 @@ from tests import FIXTURES
         ),
     ),
 )
-def test_evaluate(training_res, val_res, test_res, evaluate_config):
+def test_evaluate(capsys, training_res, val_res, test_res, evaluate_config):
     # Use the tmp_path as base folder
     evaluate_config["training"]["output_folder"] = FIXTURES / "evaluate"
 
@@ -82,3 +120,11 @@ def test_evaluate(training_res, val_res, test_res, evaluate_config):
 
     # Remove results files
     shutil.rmtree(evaluate_config["training"]["output_folder"] / "results")
+
+    # Check the metrics Markdown table
+    captured_std = capsys.readouterr()
+    last_printed_lines = captured_std.out.split("\n")[-6:]
+    assert (
+        "\n".join(last_printed_lines)
+        == Path(FIXTURES / "evaluate" / "metrics_table.md").read_text()
+    )
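For reviewers, here is a minimal usage sketch (not part of the patch) showing how the two helpers added to `dan/ocr/utils.py` compose with the new evaluation loop; the metric values below are invented for illustration only:

```python
from dan.ocr.utils import add_metrics_table_row, create_metrics_table

# Metric names as collected by the evaluation loop in dan/ocr/evaluate.py.
# "time" has no column in METRICS_TABLE_HEADER, so it is simply ignored.
metric_names = ["cer", "cer_no_token", "wer", "wer_no_punct", "wer_no_token", "time"]
metrics_table = create_metrics_table(metric_names)

# Invented values, for illustration only; in evaluate.py each dict comes from
# model.evaluate(), which now returns the computed metrics for one split.
fake_results = {
    "train": {"cer": 0.13, "cer_no_token": 0.13, "wer": 0.24, "wer_no_punct": 0.23, "wer_no_token": 0.24},
    "val": {"cer": 0.15, "cer_no_token": 0.15, "wer": 0.27, "wer_no_punct": 0.26, "wer_no_token": 0.27},
    "test": {"cer": 0.16, "cer_no_token": 0.16, "wer": 0.29, "wer_no_punct": 0.28, "wer_no_token": 0.29},
}

for split, metrics in fake_results.items():
    # Metrics missing from the dict (e.g. "ner" when no tokens are configured)
    # are rendered as "−" in the corresponding column.
    add_metrics_table_row(metrics_table, split, metrics)

# The table uses PrettyTable's MARKDOWN style, so printing it yields a Markdown
# table in the same shape as the example in docs/usage/evaluate/index.md.
print(metrics_table)
```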