Verified commit 7136907a authored by Mélodie Boillet

Remove words information

Drops the word-level label fields (word_label, token_word_label, word_label_len, nb_words) from OCRDataset samples and the corresponding padding logic and batch keys from OCRCollateFunction, keeping only the full-sequence and line-level labels.

parent 21324e0e
@@ -207,10 +207,8 @@ class OCRDataset(GenericDataset):
         line_labels = label.split("\n")
         if "remove_linebreaks" in self.params["config"]["constraints"]:
             full_label = label.replace("\n", " ").replace("  ", " ")
-            word_labels = full_label.split(" ")
         else:
             full_label = label
-            word_labels = label.replace("\n", " ").replace("  ", " ").split(" ")
         sample["label"] = full_label
         sample["token_label"] = LM_str_to_ind(self.charset, full_label)
@@ -226,13 +224,6 @@ class OCRDataset(GenericDataset):
             ]
             sample["line_label_len"] = [len(label) for label in line_labels]
             sample["nb_lines"] = len(line_labels)
-            sample["word_label"] = word_labels
-            sample["token_word_label"] = [
-                LM_str_to_ind(self.charset, label) for label in word_labels
-            ]
-            sample["word_label_len"] = [len(label) for label in word_labels]
-            sample["nb_words"] = len(word_labels)
-
         return sample

     def generate_synthetic_data(self, sample):
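For context, here is a minimal sketch of what the removed word-level sample fields contained. The LM_str_to_ind stub below is a hypothetical stand-in that maps each character to its index in the charset; the repository's real helper is defined elsewhere and may behave differently.

# Hypothetical sketch; LM_str_to_ind is a stand-in, not the repository's helper.
def LM_str_to_ind(charset, s):
    # Map each character of s to its index in the charset.
    return [charset.index(c) for c in s]

charset = list(" abcdefghijklmnopqrstuvwxyz")
label = "the cat\nsat"

full_label = label.replace("\n", " ").replace("  ", " ")
word_labels = full_label.split(" ")  # ["the", "cat", "sat"]

sample = {
    "word_label": word_labels,
    "token_word_label": [LM_str_to_ind(charset, w) for w in word_labels],
    "word_label_len": [len(w) for w in word_labels],  # [3, 3, 3]
    "nb_words": len(word_labels),  # 3
}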
@@ -530,28 +521,6 @@ class OCRCollateFunction:
                     line_len[j].append(0)
             line_len = [i for i in zip(*line_len)]

-            nb_words = [batch_data[i]["nb_words"] for i in range(len(batch_data))]
-            word_raw = [batch_data[i]["word_label"] for i in range(len(batch_data))]
-            word_token = [batch_data[i]["token_word_label"] for i in range(len(batch_data))]
-            pad_word_token = list()
-            word_len = [batch_data[i]["word_label_len"] for i in range(len(batch_data))]
-            for i in range(max(nb_words)):
-                current_words = [
-                    word_token[j][i] if i < nb_words[j] else [self.label_padding_value]
-                    for j in range(len(batch_data))
-                ]
-                pad_word_token.append(
-                    torch.tensor(
-                        pad_sequences_1D(
-                            current_words, padding_value=self.label_padding_value
-                        )
-                    ).long()
-                )
-                for j in range(len(batch_data)):
-                    if i >= nb_words[j]:
-                        word_len[j].append(0)
-            word_len = [i for i in zip(*word_len)]
-
             padding_mode = (
                 self.config["padding_mode"] if "padding_mode" in self.config else "br"
             )
@@ -585,10 +554,6 @@ class OCRCollateFunction:
                 "line_raw": line_raw,
                 "line_labels": pad_line_token,
                 "line_labels_len": line_len,
-                "nb_words": nb_words,
-                "word_raw": word_raw,
-                "word_labels": pad_word_token,
-                "word_labels_len": word_len,
             }

         return formatted_batch_data
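The removed collate step batched words by position: for each word index i it gathered the i-th word of every sample, substituting a single padding token once a sample ran out of words, and padded that group to a common length. Below is a runnable sketch of the pattern, assuming pad_sequences_1D right-pads variable-length sequences to the longest one, which is how the call above appears to use it; the helper's real definition lives elsewhere in the repository.

import torch

def pad_sequences_1D(sequences, padding_value):
    # Assumed behaviour: right-pad each sequence to the longest length.
    max_len = max(len(s) for s in sequences)
    return [list(s) + [padding_value] * (max_len - len(s)) for s in sequences]

label_padding_value = -1  # hypothetical value, for illustration only
word_token = [[[1, 2, 3], [4, 5]], [[6]]]  # token_word_label for two samples
nb_words = [2, 1]

pad_word_token = []
for i in range(max(nb_words)):
    # i-th word of every sample; exhausted samples contribute one pad token.
    current_words = [
        word_token[j][i] if i < nb_words[j] else [label_padding_value]
        for j in range(len(word_token))
    ]
    pad_word_token.append(
        torch.tensor(
            pad_sequences_1D(current_words, padding_value=label_padding_value)
        ).long()
    )

# pad_word_token[0] -> tensor([[ 1,  2,  3], [ 6, -1, -1]])
# pad_word_token[1] -> tensor([[ 4,  5], [-1, -1]])

Padding per word position, rather than flattening all words into one tensor, keeps a rectangular tensor for each position, so a word-level loss could be computed position by position across the batch.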