Skip to content
Snippets Groups Projects
Commit 31b40e73 authored by Mélodie Boillet's avatar Mélodie Boillet Committed by Yoann Schneider
Browse files

Remove lines information

parent 40a8dfb7
No related branches found
No related tags found
1 merge request!137Clean training samples
This commit is part of merge request !137. Comments created here will be created in the context of that merge request.
......@@ -204,7 +204,6 @@ class OCRDataset(GenericDataset):
def convert_sample_labels(self, sample):
label = sample["label"]
line_labels = label.split("\n")
if "remove_linebreaks" in self.params["config"]["constraints"]:
full_label = label.replace("\n", " ").replace(" ", " ")
else:
......@@ -217,13 +216,6 @@ class OCRDataset(GenericDataset):
sample["label_len"] = len(sample["token_label"])
if "add_sot" in self.params["config"]["constraints"]:
sample["token_label"].insert(0, self.tokens["start"])
sample["line_label"] = line_labels
sample["token_line_label"] = [
LM_str_to_ind(self.charset, label) for label in line_labels
]
sample["line_label_len"] = [len(label) for label in line_labels]
sample["nb_lines"] = len(line_labels)
return sample
def generate_synthetic_data(self, sample):
......@@ -499,28 +491,6 @@ class OCRCollateFunction:
batch_data[i]["unchanged_label"] for i in range(len(batch_data))
]
nb_lines = [batch_data[i]["nb_lines"] for i in range(len(batch_data))]
line_raw = [batch_data[i]["line_label"] for i in range(len(batch_data))]
line_token = [batch_data[i]["token_line_label"] for i in range(len(batch_data))]
pad_line_token = list()
line_len = [batch_data[i]["line_label_len"] for i in range(len(batch_data))]
for i in range(max(nb_lines)):
current_lines = [
line_token[j][i] if i < nb_lines[j] else [self.label_padding_value]
for j in range(len(batch_data))
]
pad_line_token.append(
torch.tensor(
pad_sequences_1D(
current_lines, padding_value=self.label_padding_value
)
).long()
)
for j in range(len(batch_data)):
if i >= nb_lines[j]:
line_len[j].append(0)
line_len = [i for i in zip(*line_len)]
padding_mode = (
self.config["padding_mode"] if "padding_mode" in self.config else "br"
)
......@@ -540,7 +510,6 @@ class OCRCollateFunction:
formatted_batch_data = {
"names": names,
"ids": ids,
"nb_lines": nb_lines,
"labels": labels,
"reverse_labels": reverse_labels,
"raw_labels": raw_labels,
......@@ -551,9 +520,6 @@ class OCRCollateFunction:
"imgs_reduced_shape": imgs_reduced_shape,
"imgs_position": imgs_position,
"imgs_reduced_position": imgs_reduced_position,
"line_raw": line_raw,
"line_labels": pad_line_token,
"line_labels_len": line_len,
}
return formatted_batch_data
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment