From be8b1a667f84c1161f57826ab15b72046cf24a2b Mon Sep 17 00:00:00 2001
From: Mélodie Boillet <boillet@teklia.com>
Date: Wed, 24 May 2023 15:08:25 +0200
Subject: [PATCH] Remove unused sample ids and clean up OCR dataset/collate code

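Drop the sample ids that were derived from sample names in
OCRCollateFunction, stored in the epoch metrics, copied into the batch
metrics in the training loops and concatenated under DDP. Also drop the
intermediate "img_shape" and "img_reduced_position" sample entries
(reading sample["img"].shape directly instead), remove the unused
"reverse_labels", "unchanged_labels", "imgs_shape" and
"imgs_reduced_position" batch entries, and build the remaining
per-sample batch lists from a single key mapping.

For reference, a minimal sketch of the resulting collate mapping, using
hypothetical sample values; the keys shown are the ones this patch keeps:

    # Two fake samples standing in for OCRDataset outputs.
    batch_data = [
        {"name": "train/page_1.jpg", "label_len": 3, "label": "abc",
         "img_position": [[0, 32], [0, 64]], "img_reduced_shape": [4, 8, 3]},
        {"name": "train/page_2.jpg", "label_len": 2, "label": "de",
         "img_position": [[0, 16], [0, 48]], "img_reduced_shape": [2, 6, 3]},
    ]
    # One dict comprehension gathers each per-sample field into a batch list,
    # pairing the batch-level key with the sample-level key it comes from.
    formatted = {
        formatted_key: [sample[initial_key] for sample in batch_data]
        for formatted_key, initial_key in zip(
            ["names", "labels_len", "raw_labels",
             "imgs_position", "imgs_reduced_shape"],
            ["name", "label_len", "label",
             "img_position", "img_reduced_shape"],
        )
    }
    assert formatted["raw_labels"] == ["abc", "de"]
    assert formatted["labels_len"] == [3, 2]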
---
 dan/manager/metrics.py  |  1 -
 dan/manager/ocr.py      | 76 ++++++++++++-----------------------------
 dan/manager/training.py |  4 ---
 3 files changed, 22 insertions(+), 59 deletions(-)

diff --git a/dan/manager/metrics.py b/dan/manager/metrics.py
index abb8c166..83aaa78c 100644
--- a/dan/manager/metrics.py
+++ b/dan/manager/metrics.py
@@ -55,7 +55,6 @@ class MetricManager:
         self.epoch_metrics = {
             "nb_samples": list(),
             "names": list(),
-            "ids": list(),
         }
 
         for metric_name in self.metric_names:
diff --git a/dan/manager/ocr.py b/dan/manager/ocr.py
index a3c44c82..2132ff62 100644
--- a/dan/manager/ocr.py
+++ b/dan/manager/ocr.py
@@ -146,9 +146,8 @@ class OCRDataset(GenericDataset):
         if "normalize" in self.params["config"] and self.params["config"]["normalize"]:
             sample["img"] = (sample["img"] - self.mean) / self.std
 
-        sample["img_shape"] = sample["img"].shape
         sample["img_reduced_shape"] = np.ceil(
-            sample["img_shape"] / self.reduce_dims_factor
+            sample["img"].shape / self.reduce_dims_factor
         ).astype(int)
 
         if self.set_name == "train":
@@ -157,8 +156,8 @@ class OCRDataset(GenericDataset):
             ]
 
         sample["img_position"] = [
-            [0, sample["img_shape"][0]],
-            [0, sample["img_shape"][1]],
+            [0, sample["img"].shape[0]],
+            [0, sample["img"].shape[1]],
         ]
         # Padding constraints to handle model needs
         if "padding" in self.params["config"] and self.params["config"]["padding"]:
@@ -189,10 +188,6 @@ class OCRDataset(GenericDataset):
                     padding_mode=self.params["config"]["padding"]["mode"],
                     return_position=True,
                 )
-        sample["img_reduced_position"] = [
-            np.ceil(p / factor).astype(int)
-            for p, factor in zip(sample["img_position"], self.reduce_dims_factor[:2])
-        ]
         return sample
 
     def convert_labels(self):
@@ -461,67 +456,40 @@ class OCRCollateFunction:
         self.config = config
 
     def __call__(self, batch_data):
-        names = [batch_data[i]["name"] for i in range(len(batch_data))]
-        ids = [
-            batch_data[i]["name"].split("/")[-1].split(".")[0]
-            for i in range(len(batch_data))
-        ]
-
         labels = [batch_data[i]["token_label"] for i in range(len(batch_data))]
         labels = pad_sequences_1D(labels, padding_value=self.label_padding_value)
         labels = torch.tensor(labels).long()
-        reverse_labels = [
-            [
-                batch_data[i]["token_label"][0],
-            ]
-            + batch_data[i]["token_label"][-2:0:-1]
-            + [
-                batch_data[i]["token_label"][-1],
-            ]
-            for i in range(len(batch_data))
-        ]
-        reverse_labels = pad_sequences_1D(
-            reverse_labels, padding_value=self.label_padding_value
-        )
-        reverse_labels = torch.tensor(reverse_labels).long()
-        labels_len = [batch_data[i]["label_len"] for i in range(len(batch_data))]
-
-        raw_labels = [batch_data[i]["label"] for i in range(len(batch_data))]
-        unchanged_labels = [
-            batch_data[i]["unchanged_label"] for i in range(len(batch_data))
-        ]
 
         padding_mode = (
             self.config["padding_mode"] if "padding_mode" in self.config else "br"
         )
         imgs = [batch_data[i]["img"] for i in range(len(batch_data))]
-        imgs_shape = [batch_data[i]["img_shape"] for i in range(len(batch_data))]
-        imgs_reduced_shape = [
-            batch_data[i]["img_reduced_shape"] for i in range(len(batch_data))
-        ]
-        imgs_position = [batch_data[i]["img_position"] for i in range(len(batch_data))]
-        imgs_reduced_position = [
-            batch_data[i]["img_reduced_position"] for i in range(len(batch_data))
-        ]
         imgs = pad_images(
             imgs, padding_value=self.img_padding_value, padding_mode=padding_mode
         )
         imgs = torch.tensor(imgs).float().permute(0, 3, 1, 2)
+
         formatted_batch_data = {
-            "names": names,
-            "ids": ids,
-            "labels": labels,
-            "reverse_labels": reverse_labels,
-            "raw_labels": raw_labels,
-            "unchanged_labels": unchanged_labels,
-            "labels_len": labels_len,
-            "imgs": imgs,
-            "imgs_shape": imgs_shape,
-            "imgs_reduced_shape": imgs_reduced_shape,
-            "imgs_position": imgs_position,
-            "imgs_reduced_position": imgs_reduced_position,
+            formatted_key: [batch_data[i][initial_key] for i in range(len(batch_data))]
+            for formatted_key, initial_key in zip(
+                [
+                    "names",
+                    "labels_len",
+                    "raw_labels",
+                    "imgs_position",
+                    "imgs_reduced_shape",
+                ],
+                ["name", "label_len", "label", "img_position", "img_reduced_shape"],
+            )
         }
 
+        formatted_batch_data.update(
+            {
+                "imgs": imgs,
+                "labels": labels,
+            }
+        )
+
         return formatted_batch_data
 
 
diff --git a/dan/manager/training.py b/dan/manager/training.py
index 53012a7c..25855e1b 100644
--- a/dan/manager/training.py
+++ b/dan/manager/training.py
@@ -617,7 +617,6 @@ class GenericTrainingManager:
                         batch_values, metric_names
                     )
                     batch_metrics["names"] = batch_data["names"]
-                    batch_metrics["ids"] = batch_data["ids"]
                     # Merge metrics if Distributed Data Parallel is used
                     if self.params["training_params"]["use_ddp"]:
                         batch_metrics = self.merge_ddp_metrics(batch_metrics)
@@ -762,7 +761,6 @@ class GenericTrainingManager:
                         batch_values, metric_names
                     )
                     batch_metrics["names"] = batch_data["names"]
-                    batch_metrics["ids"] = batch_data["ids"]
                     # merge metrics values if Distributed Data Parallel is used
                     if self.params["training_params"]["use_ddp"]:
                         batch_metrics = self.merge_ddp_metrics(batch_metrics)
@@ -815,7 +813,6 @@ class GenericTrainingManager:
                         batch_values, metric_names
                     )
                     batch_metrics["names"] = batch_data["names"]
-                    batch_metrics["ids"] = batch_data["ids"]
                     # merge batch metrics if Distributed Data Parallel is used
                     if self.params["training_params"]["use_ddp"]:
                         batch_metrics = self.merge_ddp_metrics(batch_metrics)
@@ -890,7 +887,6 @@ class GenericTrainingManager:
                 "edit_chars_force_len",
                 "edit_chars_curr",
                 "nb_chars_curr",
-                "ids",
             ]:
                 metrics[metric_name] = self.cat_ddp_metric(metrics[metric_name])
             elif metric_name in [
-- 
GitLab