Commit 65128721 authored by Yoann Schneider

Merge branch 'evaluate-ner-distance' into 'main'

Evaluate NER tokens distance

Closes #237

See merge request !327
parents 4df2bbd3 a0352637
@@ -51,6 +51,9 @@ def eval(rank, config, mlflow_logging):
    model.load_model()
    metrics = ["cer", "wer", "wer_no_punct", "time"]
    if config["dataset"]["tokens"] is not None:
        metrics.append("ner")
    for dataset_name in config["dataset"]["datasets"]:
        for set_name in ["test", "val", "train"]:
            logger.info(f"Evaluating on set `{set_name}`")
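For context, the guard above only extends the default metric list when named-entity tokens are configured. A minimal, hedged sketch of that behaviour with a hypothetical `config` dictionary (the real configuration is parsed elsewhere in the project):

# Hypothetical configuration, for illustration only.
config = {"dataset": {"tokens": {"person": "Ⓟ"}, "datasets": {"demo": "path/to/demo"}}}

metrics = ["cer", "wer", "wer_no_punct", "time"]
if config["dataset"]["tokens"] is not None:
    metrics.append("ner")

print(metrics)  # ['cer', 'wer', 'wer_no_punct', 'time', 'ner']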
@@ -32,6 +32,8 @@ class MetricManager:
            + list(map(attrgetter("end"), tokens.values()))
        )
        self.remove_tokens: re.Pattern = re.compile(r"([" + layout_tokens + "])")
        self.keep_tokens: re.Pattern = re.compile(r"([^" + layout_tokens + "])")
        self.metric_names: List[str] = metric_names
        self.epoch_metrics = defaultdict(list)
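The `+ list(map(attrgetter("end"), tokens.values()))` context line suggests that `layout_tokens` concatenates every entity's start and end markers. A self-contained sketch of how the two regular expressions then behave; the `EntityType` tuple and the circled-letter markers are assumptions for illustration, not the project's actual definitions:

import re
from collections import namedtuple
from operator import attrgetter

# Hypothetical entity definition; the real tokens come from the dataset configuration.
EntityType = namedtuple("EntityType", ["start", "end"])
tokens = {"person": EntityType(start="Ⓟ", end="Ⓡ")}

# Concatenate every start and end marker into one character class, as in the constructor above.
layout_tokens = "".join(
    list(map(attrgetter("start"), tokens.values()))
    + list(map(attrgetter("end"), tokens.values()))
)

remove_tokens = re.compile(r"([" + layout_tokens + "])")  # matches only the entity markers
keep_tokens = re.compile(r"([^" + layout_tokens + "])")   # matches everything except the markers

print(remove_tokens.sub("", "ⓅJohn DoeⓇ"))  # John Doe
print(keep_tokens.sub("", "ⓅJohn DoeⓇ"))    # ⓅⓇ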
@@ -69,6 +71,12 @@ class MetricManager:
        text = REGEX_CONSECUTIVE_LINEBREAKS.sub("\n", text)
        return REGEX_CONSECUTIVE_SPACES.sub(" ", text).strip()

    def format_string_for_ner(self, text: str):
        """
        Format string for NER computation: only keep layout tokens
        """
        return self.keep_tokens.sub("", text)

    def update_metrics(self, batch_metrics):
        """
        Add batch metrics to the metrics
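In other words, `format_string_for_ner` discards the transcription itself and keeps only the sequence of entity markers, so the NER metric measures tagging errors independently of character errors. A small hedged example, reusing the hypothetical Ⓟ/Ⓡ markers from the sketch above:

import re

keep_tokens = re.compile(r"([^ⓅⓇ])")  # hypothetical marker set, as in the previous sketch

ground_truth = "ⓅJohn DoeⓇ wrote this letter"
prediction = "ⓅJohn Doe wrote this letter"  # the closing marker was missed

print(keep_tokens.sub("", ground_truth))  # ⓅⓇ
print(keep_tokens.sub("", prediction))    # Ⓟ

Comparing the two reduced strings with an edit distance counts one token error out of two expected tokens.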
@@ -100,6 +108,8 @@ class MetricManager:
            case "wer" | "wer_no_punct":
                suffix = metric_name[3:]
                num_name, denom_name = "edit_words" + suffix, "nb_words" + suffix
            case "ner":
                num_name, denom_name = "edit_tokens", "nb_tokens"
            case "loss" | "loss_ce":
                display_values[metric_name] = round(
                    float(
@@ -152,6 +162,13 @@ class MetricManager:
                    map(editdistance.eval, split_gt, split_pred)
                )
                metrics["nb_words" + suffix] = list(map(len, split_gt))
            case "ner":
                split_gt = list(map(self.format_string_for_ner, gt))
                split_pred = list(map(self.format_string_for_ner, prediction))
                metrics["edit_tokens"] = list(
                    map(editdistance.eval, split_gt, split_pred)
                )
                metrics["nb_tokens"] = list(map(len, split_gt))
            case "loss" | "loss_ce":
                metrics[metric_name] = [
                    values[metric_name],
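Taken together with the `edit_tokens`/`nb_tokens` mapping in the earlier hunk, the reported NER score is presumably the total token edit distance divided by the total number of expected tokens over the batch or epoch. A self-contained sketch of that aggregation on a made-up batch of two samples (whether the project reports a ratio or a percentage is not visible in this diff):

import editdistance

# Token-only strings, as produced by format_string_for_ner on ground truths and predictions.
split_gt = ["ⓅⓇ", "ⓅⓇⓅⓇ"]
split_pred = ["Ⓟ", "ⓅⓇⓅⓇ"]

edit_tokens = list(map(editdistance.eval, split_gt, split_pred))  # [1, 0]
nb_tokens = list(map(len, split_gt))                              # [2, 4]

ner = sum(edit_tokens) / sum(nb_tokens)
print(f"NER: {100 * ner:.2f}%")  # NER: 16.67%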
@@ -34,6 +34,12 @@ def train(rank, params, mlflow_logging=False):
    model = Manager(params)
    model.load_model()

    if params["dataset"]["tokens"] is not None:
        if "ner" not in params["training"]["metrics"]["train"]:
            params["training"]["metrics"]["train"].append("ner")
        if "ner" not in params["training"]["metrics"]["eval"]:
            params["training"]["metrics"]["eval"].append("ner")

    if mlflow_logging:
        logger.info("MLflow logging enabled")
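Like the evaluation script, the training entry point only opts into the NER metric when entity tokens are configured, and the `not in` checks keep the metric lists free of duplicates if "ner" was already requested. A minimal sketch with a hypothetical `params` dictionary, for illustration only:

# Hypothetical training parameters, for illustration only.
params = {
    "dataset": {"tokens": {"person": "Ⓟ"}},
    "training": {"metrics": {"train": ["loss_ce", "cer", "wer"], "eval": ["cer", "wer", "ner"]}},
}

if params["dataset"]["tokens"] is not None:
    for split in ("train", "eval"):
        if "ner" not in params["training"]["metrics"][split]:
            params["training"]["metrics"][split].append("ner")

print(params["training"]["metrics"]["train"])  # ['loss_ce', 'cer', 'wer', 'ner']
print(params["training"]["metrics"]["eval"])   # ['cer', 'wer', 'ner']  (unchanged, already present)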