From 7136907a1e394a7ae3eea71c3bfdc22e82789d7e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?M=C3=A9lodie=20Boillet?= <boillet@teklia.com>
Date: Wed, 24 May 2023 14:36:43 +0200
Subject: [PATCH] Remove word-level information

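Drop the word-level fields (word_label, token_word_label,
word_label_len, nb_words) built in OCRDataset and the corresponding
word padding logic in OCRCollateFunction, leaving the character- and
line-level labels unchanged.
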
---
 dan/manager/ocr.py | 35 -----------------------------------
 1 file changed, 35 deletions(-)

diff --git a/dan/manager/ocr.py b/dan/manager/ocr.py
index 2e7ab0c4..32c798f5 100644
--- a/dan/manager/ocr.py
+++ b/dan/manager/ocr.py
@@ -207,10 +207,8 @@ class OCRDataset(GenericDataset):
         line_labels = label.split("\n")
         if "remove_linebreaks" in self.params["config"]["constraints"]:
             full_label = label.replace("\n", " ").replace("  ", " ")
-            word_labels = full_label.split(" ")
         else:
             full_label = label
-            word_labels = label.replace("\n", " ").replace("  ", " ").split(" ")
 
         sample["label"] = full_label
         sample["token_label"] = LM_str_to_ind(self.charset, full_label)
@@ -226,13 +224,6 @@ class OCRDataset(GenericDataset):
         ]
         sample["line_label_len"] = [len(label) for label in line_labels]
         sample["nb_lines"] = len(line_labels)
-
-        sample["word_label"] = word_labels
-        sample["token_word_label"] = [
-            LM_str_to_ind(self.charset, label) for label in word_labels
-        ]
-        sample["word_label_len"] = [len(label) for label in word_labels]
-        sample["nb_words"] = len(word_labels)
         return sample
 
     def generate_synthetic_data(self, sample):
@@ -530,28 +521,6 @@ class OCRCollateFunction:
                     line_len[j].append(0)
         line_len = [i for i in zip(*line_len)]
 
-        nb_words = [batch_data[i]["nb_words"] for i in range(len(batch_data))]
-        word_raw = [batch_data[i]["word_label"] for i in range(len(batch_data))]
-        word_token = [batch_data[i]["token_word_label"] for i in range(len(batch_data))]
-        pad_word_token = list()
-        word_len = [batch_data[i]["word_label_len"] for i in range(len(batch_data))]
-        for i in range(max(nb_words)):
-            current_words = [
-                word_token[j][i] if i < nb_words[j] else [self.label_padding_value]
-                for j in range(len(batch_data))
-            ]
-            pad_word_token.append(
-                torch.tensor(
-                    pad_sequences_1D(
-                        current_words, padding_value=self.label_padding_value
-                    )
-                ).long()
-            )
-            for j in range(len(batch_data)):
-                if i >= nb_words[j]:
-                    word_len[j].append(0)
-        word_len = [i for i in zip(*word_len)]
-
         padding_mode = (
             self.config["padding_mode"] if "padding_mode" in self.config else "br"
         )
@@ -585,10 +554,6 @@ class OCRCollateFunction:
             "line_raw": line_raw,
             "line_labels": pad_line_token,
             "line_labels_len": line_len,
-            "nb_words": nb_words,
-            "word_raw": word_raw,
-            "word_labels": pad_word_token,
-            "word_labels_len": word_len,
         }
 
         return formatted_batch_data
-- 
GitLab