From a035263782e8a94907981df86efbfb2939c32c65 Mon Sep 17 00:00:00 2001 From: Eva Bardou <bardou@teklia.com> Date: Thu, 23 Nov 2023 11:06:21 +0000 Subject: [PATCH] Evaluate NER tokens distance --- dan/ocr/evaluate.py | 3 +++ dan/ocr/manager/metrics.py | 17 +++++++++++++++++ dan/ocr/train.py | 6 ++++++ 3 files changed, 26 insertions(+) diff --git a/dan/ocr/evaluate.py b/dan/ocr/evaluate.py index f8b09a23..6cc95a03 100644 --- a/dan/ocr/evaluate.py +++ b/dan/ocr/evaluate.py @@ -51,6 +51,9 @@ def eval(rank, config, mlflow_logging): model.load_model() metrics = ["cer", "wer", "wer_no_punct", "time"] + if config["dataset"]["tokens"] is not None: + metrics.append("ner") + for dataset_name in config["dataset"]["datasets"]: for set_name in ["test", "val", "train"]: logger.info(f"Evaluating on set `{set_name}`") diff --git a/dan/ocr/manager/metrics.py b/dan/ocr/manager/metrics.py index f48da65c..f6b6ec74 100644 --- a/dan/ocr/manager/metrics.py +++ b/dan/ocr/manager/metrics.py @@ -32,6 +32,8 @@ class MetricManager: + list(map(attrgetter("end"), tokens.values())) ) self.remove_tokens: re.Pattern = re.compile(r"([" + layout_tokens + "])") + self.keep_tokens: re.Pattern = re.compile(r"([^" + layout_tokens + "])") + self.metric_names: List[str] = metric_names self.epoch_metrics = defaultdict(list) @@ -69,6 +71,12 @@ class MetricManager: text = REGEX_CONSECUTIVE_LINEBREAKS.sub("\n", text) return REGEX_CONSECUTIVE_SPACES.sub(" ", text).strip() + def format_string_for_ner(self, text: str): + """ + Format string for NER computation: only keep layout tokens + """ + return self.keep_tokens.sub("", text) + def update_metrics(self, batch_metrics): """ Add batch metrics to the metrics @@ -100,6 +108,8 @@ class MetricManager: case "wer" | "wer_no_punct": suffix = metric_name[3:] num_name, denom_name = "edit_words" + suffix, "nb_words" + suffix + case "ner": + num_name, denom_name = "edit_tokens", "nb_tokens" case "loss" | "loss_ce": display_values[metric_name] = round( float( @@ -152,6 +162,13 @@ class MetricManager: map(editdistance.eval, split_gt, split_pred) ) metrics["nb_words" + suffix] = list(map(len, split_gt)) + case "ner": + split_gt = list(map(self.format_string_for_ner, gt)) + split_pred = list(map(self.format_string_for_ner, prediction)) + metrics["edit_tokens"] = list( + map(editdistance.eval, split_gt, split_pred) + ) + metrics["nb_tokens"] = list(map(len, split_gt)) case "loss" | "loss_ce": metrics[metric_name] = [ values[metric_name], diff --git a/dan/ocr/train.py b/dan/ocr/train.py index 8385111c..ffc5aa15 100644 --- a/dan/ocr/train.py +++ b/dan/ocr/train.py @@ -34,6 +34,12 @@ def train(rank, params, mlflow_logging=False): model = Manager(params) model.load_model() + if params["dataset"]["tokens"] is not None: + if "ner" not in params["training"]["metrics"]["train"]: + params["training"]["metrics"]["train"].append("ner") + if "ner" not in params["training"]["metrics"]["eval"]: + params["training"]["metrics"]["eval"].append("ner") + if mlflow_logging: logger.info("MLflow logging enabled") -- GitLab