Skip to content
Snippets Groups Projects

Clean training samples

Merged Mélodie Boillet requested to merge clean-training-samples into main
1 file
+ 0
34
Compare changes
  • Side-by-side
  • Inline
+ 0
34
@@ -204,7 +204,6 @@ class OCRDataset(GenericDataset):
def convert_sample_labels(self, sample):
label = sample["label"]
line_labels = label.split("\n")
if "remove_linebreaks" in self.params["config"]["constraints"]:
full_label = label.replace("\n", " ").replace(" ", " ")
else:
@@ -217,13 +216,6 @@ class OCRDataset(GenericDataset):
sample["label_len"] = len(sample["token_label"])
if "add_sot" in self.params["config"]["constraints"]:
sample["token_label"].insert(0, self.tokens["start"])
sample["line_label"] = line_labels
sample["token_line_label"] = [
LM_str_to_ind(self.charset, label) for label in line_labels
]
sample["line_label_len"] = [len(label) for label in line_labels]
sample["nb_lines"] = len(line_labels)
return sample
def generate_synthetic_data(self, sample):
@@ -499,28 +491,6 @@ class OCRCollateFunction:
batch_data[i]["unchanged_label"] for i in range(len(batch_data))
]
nb_lines = [batch_data[i]["nb_lines"] for i in range(len(batch_data))]
line_raw = [batch_data[i]["line_label"] for i in range(len(batch_data))]
line_token = [batch_data[i]["token_line_label"] for i in range(len(batch_data))]
pad_line_token = list()
line_len = [batch_data[i]["line_label_len"] for i in range(len(batch_data))]
for i in range(max(nb_lines)):
current_lines = [
line_token[j][i] if i < nb_lines[j] else [self.label_padding_value]
for j in range(len(batch_data))
]
pad_line_token.append(
torch.tensor(
pad_sequences_1D(
current_lines, padding_value=self.label_padding_value
)
).long()
)
for j in range(len(batch_data)):
if i >= nb_lines[j]:
line_len[j].append(0)
line_len = [i for i in zip(*line_len)]
padding_mode = (
self.config["padding_mode"] if "padding_mode" in self.config else "br"
)
@@ -540,7 +510,6 @@ class OCRCollateFunction:
formatted_batch_data = {
"names": names,
"ids": ids,
"nb_lines": nb_lines,
"labels": labels,
"reverse_labels": reverse_labels,
"raw_labels": raw_labels,
@@ -551,9 +520,6 @@ class OCRCollateFunction:
"imgs_reduced_shape": imgs_reduced_shape,
"imgs_position": imgs_position,
"imgs_reduced_position": imgs_reduced_position,
"line_raw": line_raw,
"line_labels": pad_line_token,
"line_labels_len": line_len,
}
return formatted_batch_data
Loading