From c1ab7ab8a445a4352f33c3675030cea2f4995273 Mon Sep 17 00:00:00 2001
From: manonBlanco <blanco@teklia.com>
Date: Fri, 22 Dec 2023 14:15:41 +0100
Subject: [PATCH] Correctly support batches: store one Inference per sample

---
 dan/ocr/evaluate.py         |  7 +++----
 dan/ocr/manager/metrics.py  |  8 ++------
 dan/ocr/manager/training.py | 12 +++++++++++-
 3 files changed, 16 insertions(+), 11 deletions(-)

diff --git a/dan/ocr/evaluate.py b/dan/ocr/evaluate.py
index 1c22b57a..ce48c2f0 100644
--- a/dan/ocr/evaluate.py
+++ b/dan/ocr/evaluate.py
@@ -113,10 +113,9 @@ def eval(rank, config: dict, nerval_threshold: float, mlflow_logging: bool):
     def inferences_to_parsed_bio(attr: str):
         bio_values = []
         for inference in inferences:
-            values = getattr(inference, attr)
-            for value in values:
-                bio_value = convert(value, ner_tokens=tokens)
-                bio_values.extend(bio_value.split("\n"))
+            value = getattr(inference, attr)
+            bio_value = convert(value, ner_tokens=tokens)
+            bio_values.extend(bio_value.split("\n"))
 
         # Parse this BIO format
         return parse_bio(bio_values)
diff --git a/dan/ocr/manager/metrics.py b/dan/ocr/manager/metrics.py
index 305ddbc7..d7b3f6f2 100644
--- a/dan/ocr/manager/metrics.py
+++ b/dan/ocr/manager/metrics.py
@@ -26,8 +26,8 @@ METRICS_KEYWORD = {"cer": "chars", "wer": "words", "ner": "tokens"}
 
 @dataclass
 class Inference:
-    ground_truth: List[str]
-    prediction: List[str]
+    ground_truth: str
+    prediction: str
 
 
 class MetricManager:
@@ -47,9 +47,6 @@ class MetricManager:
         self.metric_names: List[str] = metric_names
         self.epoch_metrics = defaultdict(list)
 
-        # List of inferences (prediction with their ground truth)
-        self.inferences = []
-
     def format_string_for_cer(self, text: str, remove_token: bool = False):
         """
         Format string for CER computation: remove layout tokens and extra spaces
@@ -165,7 +162,6 @@ class MetricManager:
             metrics["time"] = [values["time"]]
 
         gt, prediction = values["str_y"], values["str_x"]
-        self.inferences.append(Inference(ground_truth=gt, prediction=prediction))
 
         for metric_name in metric_names:
             match metric_name:
diff --git a/dan/ocr/manager/training.py b/dan/ocr/manager/training.py
index ab09e794..7416768c 100644
--- a/dan/ocr/manager/training.py
+++ b/dan/ocr/manager/training.py
@@ -4,6 +4,7 @@ import os
 import random
 from copy import deepcopy
 from enum import Enum
+from itertools import starmap
 from pathlib import Path
 from time import time
 from typing import Dict, List, Tuple
@@ -768,6 +769,9 @@ class GenericTrainingManager:
             tokens=self.tokens,
         )
 
+        # Keep inferences in memory to evaluate with Nerval
+        inferences = []
+
         with tqdm(total=len(loader.dataset)) as pbar:
             pbar.set_description("Evaluation")
             with torch.no_grad():
@@ -792,6 +796,12 @@ class GenericTrainingManager:
                     pbar.set_postfix(values=str(display_values))
                     pbar.update(len(batch_data["names"]) * self.nb_workers)
 
+                    inferences.extend(
+                        starmap(
+                            Inference, zip(batch_values["str_y"], batch_values["str_x"])
+                        )
+                    )
+
                 # log metrics in MLflow
                 logging_name = custom_name.split("-")[1]
                 logging_tags_metrics(
@@ -810,7 +820,7 @@ class GenericTrainingManager:
             # Log mlflow artifacts
             mlflow.log_artifact(path, "predictions")
 
-        return metrics, self.metric_manager[custom_name].inferences
+        return metrics, inferences
 
     def output_pred(self, name):
         path = self.paths["results"] / "predict_{}_{}.yaml".format(
-- 
GitLab