diff --git a/.gitmodules b/.gitmodules
new file mode 100644
index 0000000000000000000000000000000000000000..7d6f8a61515402d96d9d86de1546165d8d529a87
--- /dev/null
+++ b/.gitmodules
@@ -0,0 +1,3 @@
+[submodule "nerval"]
+	path = nerval
+	url = ../../ner/nerval.git
diff --git a/dan/ocr/evaluate.py b/dan/ocr/evaluate.py
index 7d514a40714df6c6eb846f7e7463a570affc6c29..11fdf12d3c1ae5f2b302a863ef7714b85bf0becc 100644
--- a/dan/ocr/evaluate.py
+++ b/dan/ocr/evaluate.py
@@ -10,9 +10,13 @@ import numpy as np
 import torch
 import torch.multiprocessing as mp
 
+from dan.bio import convert
 from dan.ocr.manager.training import Manager
 from dan.ocr.utils import add_metrics_table_row, create_metrics_table, update_config
-from dan.utils import read_json
+from dan.utils import parse_tokens, read_json
+from nerval.evaluate import evaluate
+from nerval.parse import parse_bio
+from nerval.utils import print_results
 
 logger = logging.getLogger(__name__)
@@ -62,10 +66,12 @@ def eval(rank, config, mlflow_logging):
         metric_names.append("ner")
 
     metrics_table = create_metrics_table(metric_names)
 
+    all_inferences = {}
+
     for dataset_name in config["dataset"]["datasets"]:
         for set_name in ["train", "val", "test"]:
             logger.info(f"Evaluating on set `{set_name}`")
-            metrics = model.evaluate(
+            metrics, inferences = model.evaluate(
                 "{}-{}".format(dataset_name, set_name),
                 [
                     (dataset_name, set_name),
@@ -75,9 +81,39 @@
             )
             add_metrics_table_row(metrics_table, set_name, metrics)
+            all_inferences[set_name] = inferences
 
     print(metrics_table)
 
+    if "ner" not in metric_names:
+        return
+
+    print()
+
+    def inferences_to_parsed_bio(attr: str):
+        bio_values = []
+        for inference in inferences:
+            values = getattr(inference, attr)
+            for value in values:
+                bio_value = convert(value, ner_tokens=tokens)
+                bio_values.extend(bio_value.split("\n"))
+
+        # Parse this BIO format
+        return parse_bio(bio_values)
+
+    # Evaluate with Nerval
+    tokens = parse_tokens(config["dataset"]["tokens"])
+    for set_name, inferences in all_inferences.items():
+        ground_truths = inferences_to_parsed_bio("ground_truth")
+        predictions = inferences_to_parsed_bio("prediction")
+
+        if not (ground_truths and predictions):
+            continue
+
+        scores = evaluate(ground_truths, predictions, 0.30)
+        print(set_name)
+        print_results(scores)
+
 
 def run(config: dict):
     update_config(config)
diff --git a/dan/ocr/manager/metrics.py b/dan/ocr/manager/metrics.py
index 102fddafbaa334845dc9251c630df18a6ae353dc..305ddbc71f1026ec034dff547da36fda56fb7253 100644
--- a/dan/ocr/manager/metrics.py
+++ b/dan/ocr/manager/metrics.py
@@ -1,6 +1,7 @@
 # -*- coding: utf-8 -*-
 import re
 from collections import defaultdict
+from dataclasses import dataclass
 from operator import attrgetter
 from pathlib import Path
 from typing import Dict, List
@@ -23,6 +24,12 @@ REGEX_ONLY_ONE_SPACE = re.compile(r"\s+")
 
 METRICS_KEYWORD = {"cer": "chars", "wer": "words", "ner": "tokens"}
 
+@dataclass
+class Inference:
+    ground_truth: List[str]
+    prediction: List[str]
+
+
 class MetricManager:
     def __init__(self, metric_names: List[str], dataset_name: str, tokens: Path | None):
         self.dataset_name: str = dataset_name
@@ -40,6 +47,9 @@ class MetricManager:
         self.metric_names: List[str] = metric_names
         self.epoch_metrics = defaultdict(list)
 
+        # List of inferences (prediction with their ground truth)
+        self.inferences = []
+
     def format_string_for_cer(self, text: str, remove_token: bool = False):
         """
         Format string for CER computation: remove layout tokens and extra spaces
         """
@@ -155,6 +165,8 @@ class MetricManager:
         metrics["time"] = [values["time"]]
 
         gt, prediction = values["str_y"], values["str_x"]
+        self.inferences.append(Inference(ground_truth=gt, prediction=prediction))
+
         for metric_name in metric_names:
             match metric_name:
                 case (
diff --git a/dan/ocr/manager/training.py b/dan/ocr/manager/training.py
index ba5160fa1d3ad9ca1037077292b3530861520231..ab09e7949c863d5cb4d228393279ad799bc7c44d 100644
--- a/dan/ocr/manager/training.py
+++ b/dan/ocr/manager/training.py
@@ -6,7 +6,7 @@ from copy import deepcopy
 from enum import Enum
 from pathlib import Path
 from time import time
-from typing import Dict
+from typing import Dict, List, Tuple
 
 import numpy as np
 import torch
@@ -20,7 +20,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
 from tqdm import tqdm
 
-from dan.ocr.manager.metrics import MetricManager
+from dan.ocr.manager.metrics import Inference, MetricManager
 from dan.ocr.manager.ocr import OCRDatasetManager
 from dan.ocr.mlflow import MLFLOW_AVAILABLE, logging_metrics, logging_tags_metrics
 from dan.ocr.schedulers import DropoutScheduler
@@ -750,7 +750,7 @@ class GenericTrainingManager:
 
     def evaluate(
         self, custom_name, sets_list, metric_names, mlflow_logging=False
-    ) -> Dict[str, int | float]:
+    ) -> Tuple[Dict[str, int | float], List[Inference]]:
         """
         Main loop for evaluation
         """
@@ -810,7 +810,7 @@ class GenericTrainingManager:
                 # Log mlflow artifacts
                 mlflow.log_artifact(path, "predictions")
 
-        return metrics
+        return metrics, self.metric_manager[custom_name].inferences
 
     def output_pred(self, name):
         path = self.paths["results"] / "predict_{}_{}.yaml".format(
diff --git a/requirements.txt b/requirements.txt
index 445189ae68ef6febce26b477a5360a308208a7e3..d065f01db77260ee824b2c39ea3e778332bb93a8 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,7 +1,7 @@
+-e ./nerval
 albumentations==1.3.1
 arkindex-export==0.1.9
 boto3==1.26.124
-editdistance==0.6.2
 flashlight-text==0.0.4
 imageio==2.26.1
 imagesize==1.4.1
@@ -9,7 +9,6 @@ lxml==4.9.3
 mdutils==1.6.0
 nltk==3.8.1
 numpy==1.24.3
-prettytable==3.8.0
 PyYAML==6.0
 scipy==1.10.1
 sentencepiece==0.1.99
diff --git a/tests/data/evaluate/metrics_table.md b/tests/data/evaluate/metrics_table.md
index d67456d844071ebb8a9332638a2f527d93c5e405..b3697dab44a1ae2753ae27285f08ff161dbc3a52 100644
--- a/tests/data/evaluate/metrics_table.md
+++ b/tests/data/evaluate/metrics_table.md
@@ -3,3 +3,38 @@
 | train | 18.89 | 21.05 | 26.67 | 26.67 | 26.67 | 7.14 |
 | val | 8.82 | 11.54 | 50.0 | 50.0 | 50.0 | 0.0 |
 | test | 2.78 | 3.33 | 14.29 | 14.29 | 14.29 | 0.0 |
+
+train
+| tag | predicted | matched | Precision | Recall | F1 | Support |
+|:---------:|:---------:|:-------:|:---------:|:------:|:-----:|:-------:|
+| Surname | 2 | 2 | 1.0 | 1.0 | 1.0 | 2 |
+| Patron | 2 | 0 | 0.0 | 0.0 | 0 | 1 |
+| Operai | 2 | 2 | 1.0 | 1.0 | 1.0 | 2 |
+| Louche | 2 | 1 | 0.5 | 0.5 | 0.5 | 2 |
+| Koala | 2 | 2 | 1.0 | 1.0 | 1.0 | 2 |
+| Firstname | 2 | 2 | 1.0 | 1.0 | 1.0 | 2 |
+| Chalumeau | 1 | 1 | 1.0 | 1.0 | 1.0 | 1 |
+| Batiment | 2 | 2 | 1.0 | 1.0 | 1.0 | 2 |
+| All | 15 | 12 | 0.8 | 0.857 | 0.828 | 14 |
+val
+| tag | predicted | matched | Precision | Recall | F1 | Support |
+|:---------:|:---------:|:-------:|:---------:|:------:|:----:|:-------:|
+| Surname | 1 | 0 | 0.0 | 0.0 | 0 | 1 |
+| Patron | 1 | 1 | 1.0 | 1.0 | 1.0 | 1 |
+| Operai | 1 | 0 | 0.0 | 0.0 | 0 | 1 |
+| Louche | 1 | 1 | 1.0 | 1.0 | 1.0 | 1 |
+| Koala | 1 | 1 | 1.0 | 1.0 | 1.0 | 1 |
+| Firstname | 1 | 1 | 1.0 | 1.0 | 1.0 | 1 |
+| Chalumeau | 1 | 1 | 1.0 | 1.0 | 1.0 | 1 |
+| Batiment | 1 | 1 | 1.0 | 1.0 | 1.0 | 1 |
+| All | 8 | 6 | 0.75 | 0.75 | 0.75 | 8 |
+test
+| tag | predicted | matched | Precision | Recall | F1 | Support |
+|:---------:|:---------:|:-------:|:---------:|:------:|:-----:|:-------:|
+| Surname | 1 | 1 | 1.0 | 1.0 | 1.0 | 1 |
+| Louche | 1 | 1 | 1.0 | 1.0 | 1.0 | 1 |
+| Koala | 1 | 1 | 1.0 | 1.0 | 1.0 | 1 |
+| Firstname | 1 | 1 | 1.0 | 1.0 | 1.0 | 1 |
+| Chalumeau | 1 | 0 | 0.0 | 0.0 | 0 | 1 |
+| Batiment | 1 | 1 | 1.0 | 1.0 | 1.0 | 1 |
+| All | 6 | 5 | 0.833 | 0.833 | 0.833 | 6 |
diff --git a/tests/test_evaluate.py b/tests/test_evaluate.py
index 98fbd93898f09b75377b3b8e9cd21ddbc1b7521b..38cec57b7ccfb0ab18d12802938202d9fb65c872 100644
--- a/tests/test_evaluate.py
+++ b/tests/test_evaluate.py
@@ -129,7 +129,7 @@ def test_evaluate(capsys, training_res, val_res, test_res, evaluate_config):
 
     # Check the metrics Markdown table
     captured_std = capsys.readouterr()
-    last_printed_lines = captured_std.out.split("\n")[-6:]
+    last_printed_lines = captured_std.out.split("\n")[-41:]
     assert (
         "\n".join(last_printed_lines)
         == Path(FIXTURES / "evaluate" / "metrics_table.md").read_text()
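
For context, a minimal standalone sketch of the Nerval flow this diff wires into `eval()`: the transcriptions carried by the new `Inference` dataclass are converted to BIO with `dan.bio.convert`, parsed with `nerval.parse.parse_bio`, and scored with `nerval.evaluate.evaluate` at the same 0.30 threshold. Only the imports and call shapes come from the diff; the tokens path and example transcriptions below are placeholders, not part of the change.

```python
# Sketch only: mirrors the evaluation flow added in dan/ocr/evaluate.py.
# Assumes the nerval submodule is installed (pip install -e ./nerval);
# the tokens path and the example transcriptions are hypothetical.
from dan.bio import convert
from dan.utils import parse_tokens
from nerval.evaluate import evaluate
from nerval.parse import parse_bio
from nerval.utils import print_results

tokens = parse_tokens("tokens.yml")  # entity token definitions (placeholder path)


def to_parsed_bio(transcriptions: list[str]):
    """Convert DAN transcriptions with entity tokens into Nerval's parsed BIO form."""
    bio_lines = []
    for text in transcriptions:
        # convert() returns a newline-separated BIO string, split into lines as in the diff
        bio_lines.extend(convert(text, ner_tokens=tokens).split("\n"))
    return parse_bio(bio_lines)


ground_truths = to_parsed_bio(["<ground truth transcription>"])  # placeholder input
predictions = to_parsed_bio(["<predicted transcription>"])  # placeholder input

# 0.30 is the character-level match threshold used in evaluate.py above
scores = evaluate(ground_truths, predictions, 0.30)
print_results(scores)
```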