From c1ab7ab8a445a4352f33c3675030cea2f4995273 Mon Sep 17 00:00:00 2001
From: manonBlanco <blanco@teklia.com>
Date: Fri, 22 Dec 2023 14:15:41 +0100
Subject: [PATCH] Correctly support batch

---
 dan/ocr/evaluate.py         |  7 +++----
 dan/ocr/manager/metrics.py  |  8 ++------
 dan/ocr/manager/training.py | 12 +++++++++++-
 3 files changed, 16 insertions(+), 11 deletions(-)

diff --git a/dan/ocr/evaluate.py b/dan/ocr/evaluate.py
index 1c22b57a..ce48c2f0 100644
--- a/dan/ocr/evaluate.py
+++ b/dan/ocr/evaluate.py
@@ -113,10 +113,9 @@ def eval(rank, config: dict, nerval_threshold: float, mlflow_logging: bool):
     def inferences_to_parsed_bio(attr: str):
         bio_values = []
         for inference in inferences:
-            values = getattr(inference, attr)
-            for value in values:
-                bio_value = convert(value, ner_tokens=tokens)
-                bio_values.extend(bio_value.split("\n"))
+            value = getattr(inference, attr)
+            bio_value = convert(value, ner_tokens=tokens)
+            bio_values.extend(bio_value.split("\n"))
 
         # Parse this BIO format
         return parse_bio(bio_values)
diff --git a/dan/ocr/manager/metrics.py b/dan/ocr/manager/metrics.py
index 305ddbc7..d7b3f6f2 100644
--- a/dan/ocr/manager/metrics.py
+++ b/dan/ocr/manager/metrics.py
@@ -26,8 +26,8 @@ METRICS_KEYWORD = {"cer": "chars", "wer": "words", "ner": "tokens"}
 
 @dataclass
 class Inference:
-    ground_truth: List[str]
-    prediction: List[str]
+    ground_truth: str
+    prediction: str
 
 
 class MetricManager:
@@ -47,9 +47,6 @@ class MetricManager:
         self.metric_names: List[str] = metric_names
         self.epoch_metrics = defaultdict(list)
 
-        # List of inferences (prediction with their ground truth)
-        self.inferences = []
-
     def format_string_for_cer(self, text: str, remove_token: bool = False):
         """
         Format string for CER computation: remove layout tokens and extra spaces
@@ -165,7 +162,6 @@ class MetricManager:
             metrics["time"] = [values["time"]]
 
         gt, prediction = values["str_y"], values["str_x"]
-        self.inferences.append(Inference(ground_truth=gt, prediction=prediction))
 
         for metric_name in metric_names:
             match metric_name:
diff --git a/dan/ocr/manager/training.py b/dan/ocr/manager/training.py
index ab09e794..7416768c 100644
--- a/dan/ocr/manager/training.py
+++ b/dan/ocr/manager/training.py
@@ -4,6 +4,7 @@ import os
 import random
 from copy import deepcopy
 from enum import Enum
+from itertools import starmap
 from pathlib import Path
 from time import time
 from typing import Dict, List, Tuple
@@ -768,6 +769,9 @@ class GenericTrainingManager:
                 tokens=self.tokens,
             )
 
+        # Keep inferences in memory to evaluate with Nerval
+        inferences = []
+
         with tqdm(total=len(loader.dataset)) as pbar:
             pbar.set_description("Evaluation")
             with torch.no_grad():
@@ -792,6 +796,12 @@ class GenericTrainingManager:
                     pbar.set_postfix(values=str(display_values))
                     pbar.update(len(batch_data["names"]) * self.nb_workers)
 
+                    inferences.extend(
+                        starmap(
+                            Inference, zip(batch_values["str_y"], batch_values["str_x"])
+                        )
+                    )
+
         # log metrics in MLflow
         logging_name = custom_name.split("-")[1]
         logging_tags_metrics(
@@ -810,7 +820,7 @@ class GenericTrainingManager:
         # Log mlflow artifacts
        mlflow.log_artifact(path, "predictions")
 
-        return metrics, self.metric_manager[custom_name].inferences
+        return metrics, inferences
 
     def output_pred(self, name):
        path = self.paths["results"] / "predict_{}_{}.yaml".format(
-- 
GitLab
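
Below is a minimal, standalone sketch (not part of the patch) of the batched collection pattern that training.py now uses: each Inference holds a single ground-truth/prediction string pair, and starmap over zip(batch_values["str_y"], batch_values["str_x"]) builds one Inference per sample in the batch. The batch_values dictionary here is a hypothetical stand-in for the real per-batch values.

from dataclasses import dataclass
from itertools import starmap
from typing import List


@dataclass
class Inference:
    ground_truth: str
    prediction: str


# Hypothetical batch values: parallel lists of ground-truth and predicted strings.
batch_values = {
    "str_y": ["the cat", "a dog"],
    "str_x": ["the cat", "a bog"],
}

inferences: List[Inference] = []
# One Inference per sample: starmap unpacks each (ground_truth, prediction) pair.
inferences.extend(
    starmap(Inference, zip(batch_values["str_y"], batch_values["str_x"]))
)
assert inferences[1] == Inference(ground_truth="a dog", prediction="a bog")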