From 7136907a1e394a7ae3eea71c3bfdc22e82789d7e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?M=C3=A9lodie=20Boillet?= <boillet@teklia.com> Date: Wed, 24 May 2023 14:36:43 +0200 Subject: [PATCH] Remove words information --- dan/manager/ocr.py | 35 ----------------------------------- 1 file changed, 35 deletions(-) diff --git a/dan/manager/ocr.py b/dan/manager/ocr.py index 2e7ab0c4..32c798f5 100644 --- a/dan/manager/ocr.py +++ b/dan/manager/ocr.py @@ -207,10 +207,8 @@ class OCRDataset(GenericDataset): line_labels = label.split("\n") if "remove_linebreaks" in self.params["config"]["constraints"]: full_label = label.replace("\n", " ").replace(" ", " ") - word_labels = full_label.split(" ") else: full_label = label - word_labels = label.replace("\n", " ").replace(" ", " ").split(" ") sample["label"] = full_label sample["token_label"] = LM_str_to_ind(self.charset, full_label) @@ -226,13 +224,6 @@ class OCRDataset(GenericDataset): ] sample["line_label_len"] = [len(label) for label in line_labels] sample["nb_lines"] = len(line_labels) - - sample["word_label"] = word_labels - sample["token_word_label"] = [ - LM_str_to_ind(self.charset, label) for label in word_labels - ] - sample["word_label_len"] = [len(label) for label in word_labels] - sample["nb_words"] = len(word_labels) return sample def generate_synthetic_data(self, sample): @@ -530,28 +521,6 @@ class OCRCollateFunction: line_len[j].append(0) line_len = [i for i in zip(*line_len)] - nb_words = [batch_data[i]["nb_words"] for i in range(len(batch_data))] - word_raw = [batch_data[i]["word_label"] for i in range(len(batch_data))] - word_token = [batch_data[i]["token_word_label"] for i in range(len(batch_data))] - pad_word_token = list() - word_len = [batch_data[i]["word_label_len"] for i in range(len(batch_data))] - for i in range(max(nb_words)): - current_words = [ - word_token[j][i] if i < nb_words[j] else [self.label_padding_value] - for j in range(len(batch_data)) - ] - pad_word_token.append( - torch.tensor( - pad_sequences_1D( - current_words, padding_value=self.label_padding_value - ) - ).long() - ) - for j in range(len(batch_data)): - if i >= nb_words[j]: - word_len[j].append(0) - word_len = [i for i in zip(*word_len)] - padding_mode = ( self.config["padding_mode"] if "padding_mode" in self.config else "br" ) @@ -585,10 +554,6 @@ class OCRCollateFunction: "line_raw": line_raw, "line_labels": pad_line_token, "line_labels_len": line_len, - "nb_words": nb_words, - "word_raw": word_raw, - "word_labels": pad_word_token, - "word_labels_len": word_len, } return formatted_batch_data -- GitLab