Skip to content
Snippets Groups Projects
Commit c1ab7ab8 authored by Manon Blanco
Browse files

Correctly support batch

parent b96f5824
No related branches found
No related tags found
No related merge requests found
......@@ -113,10 +113,9 @@ def eval(rank, config: dict, nerval_threshold: float, mlflow_logging: bool):
def inferences_to_parsed_bio(attr: str):
bio_values = []
for inference in inferences:
values = getattr(inference, attr)
for value in values:
bio_value = convert(value, ner_tokens=tokens)
bio_values.extend(bio_value.split("\n"))
value = getattr(inference, attr)
bio_value = convert(value, ner_tokens=tokens)
bio_values.extend(bio_value.split("\n"))
# Parse this BIO format
return parse_bio(bio_values)
......
......@@ -26,8 +26,8 @@ METRICS_KEYWORD = {"cer": "chars", "wer": "words", "ner": "tokens"}
@dataclass
class Inference:
ground_truth: List[str]
prediction: List[str]
ground_truth: str
prediction: str
class MetricManager:
......@@ -47,9 +47,6 @@ class MetricManager:
self.metric_names: List[str] = metric_names
self.epoch_metrics = defaultdict(list)
# List of inferences (prediction with their ground truth)
self.inferences = []
def format_string_for_cer(self, text: str, remove_token: bool = False):
"""
Format string for CER computation: remove layout tokens and extra spaces
......@@ -165,7 +162,6 @@ class MetricManager:
metrics["time"] = [values["time"]]
gt, prediction = values["str_y"], values["str_x"]
self.inferences.append(Inference(ground_truth=gt, prediction=prediction))
for metric_name in metric_names:
match metric_name:
......
......@@ -4,6 +4,7 @@ import os
import random
from copy import deepcopy
from enum import Enum
from itertools import starmap
from pathlib import Path
from time import time
from typing import Dict, List, Tuple
......@@ -768,6 +769,9 @@ class GenericTrainingManager:
tokens=self.tokens,
)
# Keep inferences in memory to evaluate with Nerval
inferences = []
with tqdm(total=len(loader.dataset)) as pbar:
pbar.set_description("Evaluation")
with torch.no_grad():
......@@ -792,6 +796,12 @@ class GenericTrainingManager:
pbar.set_postfix(values=str(display_values))
pbar.update(len(batch_data["names"]) * self.nb_workers)
inferences.extend(
starmap(
Inference, zip(batch_values["str_y"], batch_values["str_x"])
)
)
# log metrics in MLflow
logging_name = custom_name.split("-")[1]
logging_tags_metrics(
......@@ -810,7 +820,7 @@ class GenericTrainingManager:
# Log mlflow artifacts
mlflow.log_artifact(path, "predictions")
return metrics, self.metric_manager[custom_name].inferences
return metrics, inferences
def output_pred(self, name):
path = self.paths["results"] / "predict_{}_{}.yaml".format(
......
0% Loading. Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment