Commit a0352637 authored by Eva Bardou, committed by Yoann Schneider

Evaluate NER tokens distance

parent 4df2bbd3
Merge request !327: Evaluate NER tokens distance
@@ -51,6 +51,9 @@ def eval(rank, config, mlflow_logging):
     model.load_model()
 
     metrics = ["cer", "wer", "wer_no_punct", "time"]
+    if config["dataset"]["tokens"] is not None:
+        metrics.append("ner")
+
     for dataset_name in config["dataset"]["datasets"]:
         for set_name in ["test", "val", "train"]:
             logger.info(f"Evaluating on set `{set_name}`")
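For illustration, a minimal sketch of the new condition in the evaluation entry point. The token objects are assumed to expose `start` and `end` attributes, as suggested by the `attrgetter` calls in the `MetricManager` hunk below; the exact configuration shape is not part of this commit.

# Hypothetical config shape; only the tokens-is-not-None check is taken from the diff.
from collections import namedtuple

Entity = namedtuple("Entity", ["start", "end"])
config = {"dataset": {"tokens": {"person": Entity(start="Ⓟ", end="Ⓡ")}}}

metrics = ["cer", "wer", "wer_no_punct", "time"]
if config["dataset"]["tokens"] is not None:
    metrics.append("ner")
print(metrics)  # ['cer', 'wer', 'wer_no_punct', 'time', 'ner']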
@@ -32,6 +32,8 @@ class MetricManager:
             + list(map(attrgetter("end"), tokens.values()))
         )
         self.remove_tokens: re.Pattern = re.compile(r"([" + layout_tokens + "])")
+        self.keep_tokens: re.Pattern = re.compile(r"([^" + layout_tokens + "])")
+
         self.metric_names: List[str] = metric_names
         self.epoch_metrics = defaultdict(list)
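The new `keep_tokens` pattern is the complement of the existing `remove_tokens` pattern: one strips the layout tokens, the other strips everything but them. A small sketch with placeholder markers:

# Placeholder markers; the real layout_tokens string is built from the
# start/end attributes of the configured NER tokens.
import re

layout_tokens = "ⓅⓇ"
remove_tokens = re.compile(r"([" + layout_tokens + "])")
keep_tokens = re.compile(r"([^" + layout_tokens + "])")

text = "ⓅJohn DoeⓇ was born in 1876"
print(remove_tokens.sub("", text))  # "John Doe was born in 1876"
print(keep_tokens.sub("", text))    # "ⓅⓇ"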
@@ -69,6 +71,12 @@ class MetricManager:
         text = REGEX_CONSECUTIVE_LINEBREAKS.sub("\n", text)
         return REGEX_CONSECUTIVE_SPACES.sub(" ", text).strip()
 
+    def format_string_for_ner(self, text: str):
+        """
+        Format string for NER computation: only keep layout tokens
+        """
+        return self.keep_tokens.sub("", text)
+
     def update_metrics(self, batch_metrics):
         """
         Add batch metrics to the metrics
@@ -100,6 +108,8 @@ class MetricManager:
             case "wer" | "wer_no_punct":
                 suffix = metric_name[3:]
                 num_name, denom_name = "edit_words" + suffix, "nb_words" + suffix
+            case "ner":
+                num_name, denom_name = "edit_tokens", "nb_tokens"
             case "loss" | "loss_ce":
                 display_values[metric_name] = round(
                     float(
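The `case "ner"` branch above only selects which accumulated lists feed the displayed value. By analogy with `cer` and `wer`, the reported score is presumably the ratio of summed edit distances to summed token counts, along these lines:

# Assumed aggregation, mirroring the usual edit-distance metrics in this class.
edit_tokens = [1, 0, 2]  # per-sample edit distances between token sequences
nb_tokens = [4, 3, 5]    # per-sample number of layout tokens in the ground truth

ner = sum(edit_tokens) / sum(nb_tokens)
print(round(ner, 4))  # 0.25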
@@ -152,6 +162,13 @@ class MetricManager:
                     map(editdistance.eval, split_gt, split_pred)
                 )
                 metrics["nb_words" + suffix] = list(map(len, split_gt))
+            case "ner":
+                split_gt = list(map(self.format_string_for_ner, gt))
+                split_pred = list(map(self.format_string_for_ner, prediction))
+                metrics["edit_tokens"] = list(
+                    map(editdistance.eval, split_gt, split_pred)
+                )
+                metrics["nb_tokens"] = list(map(len, split_gt))
             case "loss" | "loss_ce":
                 metrics[metric_name] = [
                     values[metric_name],
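Putting the pieces together, the per-sample computation added in this hunk reduces both strings to their layout-token sequences and measures the edit distance between them. A self-contained sketch with placeholder markers, assuming the `editdistance` package used above:

# Sketch only: marker characters are placeholders for the configured NER tokens.
import re

import editdistance

layout_tokens = "ⓅⓇⒹⒻ"
keep_tokens = re.compile(r"([^" + layout_tokens + "])")

def format_string_for_ner(text: str) -> str:
    # Keep only the layout (entity) tokens, drop everything else.
    return keep_tokens.sub("", text)

gt = "ⓅJohnⓇ born ⒹParisⒻ"
prediction = "ⓅJohnⓇ born Paris"  # the model missed the place tags

split_gt = format_string_for_ner(gt)            # "ⓅⓇⒹⒻ"
split_pred = format_string_for_ner(prediction)  # "ⓅⓇ"

edit_tokens = editdistance.eval(split_gt, split_pred)  # 2
nb_tokens = len(split_gt)                              # 4
print(edit_tokens / nb_tokens)                         # 0.5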
@@ -34,6 +34,12 @@ def train(rank, params, mlflow_logging=False):
     model = Manager(params)
     model.load_model()
 
+    if params["dataset"]["tokens"] is not None:
+        if "ner" not in params["training"]["metrics"]["train"]:
+            params["training"]["metrics"]["train"].append("ner")
+        if "ner" not in params["training"]["metrics"]["eval"]:
+            params["training"]["metrics"]["eval"].append("ner")
+
     if mlflow_logging:
         logger.info("MLflow logging enabled")
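A quick illustration of the guard added to the training entry point, assuming a minimal hypothetical `params` dict. The membership checks keep `ner` from being appended twice if the same configuration is processed more than once:

# Hypothetical minimal params for illustration.
params = {
    "dataset": {"tokens": {"person": object()}},
    "training": {"metrics": {"train": ["cer", "wer"], "eval": ["cer", "wer"]}},
}

for _ in range(2):  # run the guard twice to show it is idempotent
    if params["dataset"]["tokens"] is not None:
        if "ner" not in params["training"]["metrics"]["train"]:
            params["training"]["metrics"]["train"].append("ner")
        if "ner" not in params["training"]["metrics"]["eval"]:
            params["training"]["metrics"]["eval"].append("ner")

print(params["training"]["metrics"])
# {'train': ['cer', 'wer', 'ner'], 'eval': ['cer', 'wer', 'ner']}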