From 6c5ed2e4929de10b418411f7686020d595019f64 Mon Sep 17 00:00:00 2001
From: Mélodie Boillet <boillet@teklia.com>
Date: Wed, 24 May 2023 14:45:11 +0200
Subject: [PATCH] Remove line-level label information

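Drop the per-line label fields ("line_label", "token_line_label",
"line_label_len", "nb_lines") that convert_sample_labels built for each
sample, along with the per-line padding and the "line_raw",
"line_labels" and "line_labels_len" entries assembled in
OCRCollateFunction. Only the sequence-level labels remain in the
collated batch.

Minimal sketch of reading the collated batch after this change; it
assumes an OCRCollateFunction instance is used as a callable collate
function on a list of dataset samples, which is not shown in this
patch:

    # Assumed usage sketch, not part of this patch: `collate_fn` is an
    # OCRCollateFunction instance, `samples` a list of dataset samples.
    batch = collate_fn(samples)
    labels = batch["labels"]          # token labels for the full transcriptions
    raw_labels = batch["raw_labels"]  # raw (untokenized) transcriptions
    # The per-line keys "nb_lines", "line_raw", "line_labels" and
    # "line_labels_len" are no longer present in the batch.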
---
 dan/manager/ocr.py | 34 ----------------------------------
 1 file changed, 34 deletions(-)

diff --git a/dan/manager/ocr.py b/dan/manager/ocr.py
index 42bfcadc..04694806 100644
--- a/dan/manager/ocr.py
+++ b/dan/manager/ocr.py
@@ -242,7 +242,6 @@ class OCRDataset(GenericDataset):
 
     def convert_sample_labels(self, sample):
         label = sample["label"]
-        line_labels = label.split("\n")
         if "remove_linebreaks" in self.params["config"]["constraints"]:
             full_label = label.replace("\n", " ").replace("  ", " ")
         else:
@@ -255,13 +254,6 @@ class OCRDataset(GenericDataset):
         sample["label_len"] = len(sample["token_label"])
         if "add_sot" in self.params["config"]["constraints"]:
             sample["token_label"].insert(0, self.tokens["start"])
-
-        sample["line_label"] = line_labels
-        sample["token_line_label"] = [
-            LM_str_to_ind(self.charset, label) for label in line_labels
-        ]
-        sample["line_label_len"] = [len(label) for label in line_labels]
-        sample["nb_lines"] = len(line_labels)
         return sample
 
     def generate_synthetic_data(self, sample):
@@ -537,28 +529,6 @@ class OCRCollateFunction:
             batch_data[i]["unchanged_label"] for i in range(len(batch_data))
         ]
 
-        nb_lines = [batch_data[i]["nb_lines"] for i in range(len(batch_data))]
-        line_raw = [batch_data[i]["line_label"] for i in range(len(batch_data))]
-        line_token = [batch_data[i]["token_line_label"] for i in range(len(batch_data))]
-        pad_line_token = list()
-        line_len = [batch_data[i]["line_label_len"] for i in range(len(batch_data))]
-        for i in range(max(nb_lines)):
-            current_lines = [
-                line_token[j][i] if i < nb_lines[j] else [self.label_padding_value]
-                for j in range(len(batch_data))
-            ]
-            pad_line_token.append(
-                torch.tensor(
-                    pad_sequences_1D(
-                        current_lines, padding_value=self.label_padding_value
-                    )
-                ).long()
-            )
-            for j in range(len(batch_data)):
-                if i >= nb_lines[j]:
-                    line_len[j].append(0)
-        line_len = [i for i in zip(*line_len)]
-
         padding_mode = (
             self.config["padding_mode"] if "padding_mode" in self.config else "br"
         )
@@ -578,7 +548,6 @@ class OCRCollateFunction:
         formatted_batch_data = {
             "names": names,
             "ids": ids,
-            "nb_lines": nb_lines,
             "labels": labels,
             "reverse_labels": reverse_labels,
             "raw_labels": raw_labels,
@@ -589,9 +558,6 @@ class OCRCollateFunction:
             "imgs_reduced_shape": imgs_reduced_shape,
             "imgs_position": imgs_position,
             "imgs_reduced_position": imgs_reduced_position,
-            "line_raw": line_raw,
-            "line_labels": pad_line_token,
-            "line_labels_len": line_len,
         }
 
         return formatted_batch_data
-- 
GitLab