diff --git a/dan/ocr/evaluate.py b/dan/ocr/evaluate.py index f8b09a23e018718c2160dc67740efeab931ac848..6cc95a03bd7894a9fbb09a320c0f92e4156a48a1 100644 --- a/dan/ocr/evaluate.py +++ b/dan/ocr/evaluate.py @@ -51,6 +51,9 @@ def eval(rank, config, mlflow_logging): model.load_model() metrics = ["cer", "wer", "wer_no_punct", "time"] + if config["dataset"]["tokens"] is not None: + metrics.append("ner") + for dataset_name in config["dataset"]["datasets"]: for set_name in ["test", "val", "train"]: logger.info(f"Evaluating on set `{set_name}`") diff --git a/dan/ocr/manager/metrics.py b/dan/ocr/manager/metrics.py index f48da65ca20d8976ab2067ef7828bd20b85f996e..f6b6ec74d50aca424411869a6fc2597aff4b6eb6 100644 --- a/dan/ocr/manager/metrics.py +++ b/dan/ocr/manager/metrics.py @@ -32,6 +32,8 @@ class MetricManager: + list(map(attrgetter("end"), tokens.values())) ) self.remove_tokens: re.Pattern = re.compile(r"([" + layout_tokens + "])") + self.keep_tokens: re.Pattern = re.compile(r"([^" + layout_tokens + "])") + self.metric_names: List[str] = metric_names self.epoch_metrics = defaultdict(list) @@ -69,6 +71,12 @@ class MetricManager: text = REGEX_CONSECUTIVE_LINEBREAKS.sub("\n", text) return REGEX_CONSECUTIVE_SPACES.sub(" ", text).strip() + def format_string_for_ner(self, text: str): + """ + Format string for NER computation: only keep layout tokens + """ + return self.keep_tokens.sub("", text) + def update_metrics(self, batch_metrics): """ Add batch metrics to the metrics @@ -100,6 +108,8 @@ class MetricManager: case "wer" | "wer_no_punct": suffix = metric_name[3:] num_name, denom_name = "edit_words" + suffix, "nb_words" + suffix + case "ner": + num_name, denom_name = "edit_tokens", "nb_tokens" case "loss" | "loss_ce": display_values[metric_name] = round( float( @@ -152,6 +162,13 @@ class MetricManager: map(editdistance.eval, split_gt, split_pred) ) metrics["nb_words" + suffix] = list(map(len, split_gt)) + case "ner": + split_gt = list(map(self.format_string_for_ner, gt)) + split_pred = list(map(self.format_string_for_ner, prediction)) + metrics["edit_tokens"] = list( + map(editdistance.eval, split_gt, split_pred) + ) + metrics["nb_tokens"] = list(map(len, split_gt)) case "loss" | "loss_ce": metrics[metric_name] = [ values[metric_name], diff --git a/dan/ocr/train.py b/dan/ocr/train.py index 8385111c2db9c7cc175148d14a3f49ae333cb9c6..ffc5aa158845e1cb93dea91dca91bade139e6921 100644 --- a/dan/ocr/train.py +++ b/dan/ocr/train.py @@ -34,6 +34,12 @@ def train(rank, params, mlflow_logging=False): model = Manager(params) model.load_model() + if params["dataset"]["tokens"] is not None: + if "ner" not in params["training"]["metrics"]["train"]: + params["training"]["metrics"]["train"].append("ner") + if "ner" not in params["training"]["metrics"]["eval"]: + params["training"]["metrics"]["eval"].append("ner") + if mlflow_logging: logger.info("MLflow logging enabled")