Skip to content
Snippets Groups Projects

Clean training samples

Merged Mélodie Boillet requested to merge clean-training-samples into main
1 file
+ 0
− 35
Compare changes
  • Side-by-side
  • Inline
+ 0
− 35
@@ -207,10 +207,8 @@ class OCRDataset(GenericDataset):
line_labels = label.split("\n")
if "remove_linebreaks" in self.params["config"]["constraints"]:
full_label = label.replace("\n", " ").replace(" ", " ")
word_labels = full_label.split(" ")
else:
full_label = label
word_labels = label.replace("\n", " ").replace(" ", " ").split(" ")
sample["label"] = full_label
sample["token_label"] = LM_str_to_ind(self.charset, full_label)
@@ -226,13 +224,6 @@ class OCRDataset(GenericDataset):
]
sample["line_label_len"] = [len(label) for label in line_labels]
sample["nb_lines"] = len(line_labels)
sample["word_label"] = word_labels
sample["token_word_label"] = [
LM_str_to_ind(self.charset, label) for label in word_labels
]
sample["word_label_len"] = [len(label) for label in word_labels]
sample["nb_words"] = len(word_labels)
return sample
def generate_synthetic_data(self, sample):
@@ -530,28 +521,6 @@ class OCRCollateFunction:
line_len[j].append(0)
line_len = [i for i in zip(*line_len)]
nb_words = [batch_data[i]["nb_words"] for i in range(len(batch_data))]
word_raw = [batch_data[i]["word_label"] for i in range(len(batch_data))]
word_token = [batch_data[i]["token_word_label"] for i in range(len(batch_data))]
pad_word_token = list()
word_len = [batch_data[i]["word_label_len"] for i in range(len(batch_data))]
for i in range(max(nb_words)):
current_words = [
word_token[j][i] if i < nb_words[j] else [self.label_padding_value]
for j in range(len(batch_data))
]
pad_word_token.append(
torch.tensor(
pad_sequences_1D(
current_words, padding_value=self.label_padding_value
)
).long()
)
for j in range(len(batch_data)):
if i >= nb_words[j]:
word_len[j].append(0)
word_len = [i for i in zip(*word_len)]
padding_mode = (
self.config["padding_mode"] if "padding_mode" in self.config else "br"
)
@@ -585,10 +554,6 @@ class OCRCollateFunction:
"line_raw": line_raw,
"line_labels": pad_line_token,
"line_labels_len": line_len,
"nb_words": nb_words,
"word_raw": word_raw,
"word_labels": pad_word_token,
"word_labels_len": word_len,
}
return formatted_batch_data
Loading