
Clean training samples

Merged Mélodie Boillet requested to merge clean-training-samples into main
1 file  +23  −128
@@ -128,9 +128,7 @@ class OCRDataset(GenericDataset):
             sample = self.generate_synthetic_data(sample)

         # Data augmentation
-        sample["img"], sample["applied_da"] = self.apply_data_augmentation(
-            sample["img"]
-        )
+        sample["img"] = self.apply_data_augmentation(sample["img"])

         if "max_size" in self.params["config"] and self.params["config"]["max_size"]:
             max_ratio = max(
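Note on the hunk above: after this change, apply_data_augmentation is expected to return only the transformed image, since the applied_da bookkeeping disappears here (and from the collate output further down). A minimal sketch of the new call-site contract, with a hypothetical no-op augmentation standing in for the real GenericDataset method:

    import numpy as np

    def apply_data_augmentation(img: np.ndarray) -> np.ndarray:
        # Stand-in for the dataset method: after this MR it returns the
        # augmented image only, not an (image, applied_transforms) pair.
        return img

    sample = {"img": np.zeros((64, 256, 3), dtype=np.uint8)}
    sample["img"] = apply_data_augmentation(sample["img"])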
@@ -148,9 +146,8 @@ class OCRDataset(GenericDataset):
         if "normalize" in self.params["config"] and self.params["config"]["normalize"]:
             sample["img"] = (sample["img"] - self.mean) / self.std

-        sample["img_shape"] = sample["img"].shape
         sample["img_reduced_shape"] = np.ceil(
-            sample["img_shape"] / self.reduce_dims_factor
+            sample["img"].shape / self.reduce_dims_factor
         ).astype(int)

         if self.set_name == "train":
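The img_shape key is no longer cached on the sample; the reduced shape is computed directly from the image. A worked example with hypothetical numbers, assuming a 128x1024 RGB image and a backbone that divides height by 32 and width by 8 (the reduce_dims_factor values are illustrative, not the project's defaults):

    import numpy as np

    img = np.zeros((128, 1024, 3))
    reduce_dims_factor = np.array([32, 8, 1])  # hypothetical model strides

    img_reduced_shape = np.ceil(np.array(img.shape) / reduce_dims_factor).astype(int)
    print(img_reduced_shape)  # [  4 128   3]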
@@ -159,8 +156,8 @@ class OCRDataset(GenericDataset):
             ]
         sample["img_position"] = [
-            [0, sample["img_shape"][0]],
-            [0, sample["img_shape"][1]],
+            [0, sample["img"].shape[0]],
+            [0, sample["img"].shape[1]],
         ]

         # Padding constraints to handle model needs
         if "padding" in self.params["config"] and self.params["config"]["padding"]:
@@ -191,10 +188,6 @@ class OCRDataset(GenericDataset):
                 padding_mode=self.params["config"]["padding"]["mode"],
                 return_position=True,
             )
-            sample["img_reduced_position"] = [
-                np.ceil(p / factor).astype(int)
-                for p, factor in zip(sample["img_position"], self.reduce_dims_factor[:2])
-            ]

         return sample

     def convert_labels(self):
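img_reduced_position leaves the sample entirely. If a downstream consumer still needs it, it can be recomputed from the keys that remain, using the same formula as the deleted lines; a sketch with hypothetical values:

    import numpy as np

    img_position = [[0, 128], [0, 1024]]  # kept on the sample after padding
    reduce_dims_factor = np.array([32, 8, 1])

    img_reduced_position = [
        np.ceil(np.array(p) / factor).astype(int)
        for p, factor in zip(img_position, reduce_dims_factor[:2])
    ]
    print(img_reduced_position)  # [array([0, 4]), array([  0, 128])]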
@@ -206,13 +199,10 @@ class OCRDataset(GenericDataset):

     def convert_sample_labels(self, sample):
         label = sample["label"]
-        line_labels = label.split("\n")
         if "remove_linebreaks" in self.params["config"]["constraints"]:
             full_label = label.replace("\n", " ").replace("  ", " ")
-            word_labels = full_label.split(" ")
         else:
             full_label = label
-            word_labels = label.replace("\n", " ").replace("  ", " ").split(" ")

         sample["label"] = full_label
         sample["token_label"] = LM_str_to_ind(self.charset, full_label)
@@ -221,20 +211,6 @@ class OCRDataset(GenericDataset):
         sample["label_len"] = len(sample["token_label"])
         if "add_sot" in self.params["config"]["constraints"]:
             sample["token_label"].insert(0, self.tokens["start"])
-
-        sample["line_label"] = line_labels
-        sample["token_line_label"] = [
-            LM_str_to_ind(self.charset, label) for label in line_labels
-        ]
-        sample["line_label_len"] = [len(label) for label in line_labels]
-        sample["nb_lines"] = len(line_labels)
-
-        sample["word_label"] = word_labels
-        sample["token_word_label"] = [
-            LM_str_to_ind(self.charset, label) for label in word_labels
-        ]
-        sample["word_label_len"] = [len(label) for label in word_labels]
-        sample["nb_words"] = len(word_labels)
         return sample

     def generate_synthetic_data(self, sample):
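With the line- and word-level variants gone, convert_sample_labels now yields only the full transcription, its token sequence, and their length. A toy walk-through of the remaining path, assuming LM_str_to_ind maps each character to its index in the charset (the charset and the inlined helper are stand-ins, not the project's real ones):

    charset = ["a", "b", " "]

    def LM_str_to_ind(charset, label):
        # Stand-in for the project's helper: one index per character.
        return [charset.index(c) for c in label]

    label = "ab\nba"
    full_label = label.replace("\n", " ").replace("  ", " ")  # "ab ba"
    token_label = LM_str_to_ind(charset, full_label)
    print(token_label)       # [0, 1, 2, 1, 0]
    print(len(token_label))  # 5 -> stored as sample["label_len"]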
@@ -480,121 +456,40 @@ class OCRCollateFunction:
         self.config = config

     def __call__(self, batch_data):
-        names = [batch_data[i]["name"] for i in range(len(batch_data))]
-        ids = [
-            batch_data[i]["name"].split("/")[-1].split(".")[0]
-            for i in range(len(batch_data))
-        ]
-        applied_da = [batch_data[i]["applied_da"] for i in range(len(batch_data))]
-
         labels = [batch_data[i]["token_label"] for i in range(len(batch_data))]
         labels = pad_sequences_1D(labels, padding_value=self.label_padding_value)
         labels = torch.tensor(labels).long()
-        reverse_labels = [
-            [
-                batch_data[i]["token_label"][0],
-            ]
-            + batch_data[i]["token_label"][-2:0:-1]
-            + [
-                batch_data[i]["token_label"][-1],
-            ]
-            for i in range(len(batch_data))
-        ]
-        reverse_labels = pad_sequences_1D(
-            reverse_labels, padding_value=self.label_padding_value
-        )
-        reverse_labels = torch.tensor(reverse_labels).long()
-        labels_len = [batch_data[i]["label_len"] for i in range(len(batch_data))]
-        raw_labels = [batch_data[i]["label"] for i in range(len(batch_data))]
-        unchanged_labels = [
-            batch_data[i]["unchanged_label"] for i in range(len(batch_data))
-        ]
-
-        nb_lines = [batch_data[i]["nb_lines"] for i in range(len(batch_data))]
-        line_raw = [batch_data[i]["line_label"] for i in range(len(batch_data))]
-        line_token = [batch_data[i]["token_line_label"] for i in range(len(batch_data))]
-        pad_line_token = list()
-        line_len = [batch_data[i]["line_label_len"] for i in range(len(batch_data))]
-        for i in range(max(nb_lines)):
-            current_lines = [
-                line_token[j][i] if i < nb_lines[j] else [self.label_padding_value]
-                for j in range(len(batch_data))
-            ]
-            pad_line_token.append(
-                torch.tensor(
-                    pad_sequences_1D(
-                        current_lines, padding_value=self.label_padding_value
-                    )
-                ).long()
-            )
-            for j in range(len(batch_data)):
-                if i >= nb_lines[j]:
-                    line_len[j].append(0)
-        line_len = [i for i in zip(*line_len)]
-
-        nb_words = [batch_data[i]["nb_words"] for i in range(len(batch_data))]
-        word_raw = [batch_data[i]["word_label"] for i in range(len(batch_data))]
-        word_token = [batch_data[i]["token_word_label"] for i in range(len(batch_data))]
-        pad_word_token = list()
-        word_len = [batch_data[i]["word_label_len"] for i in range(len(batch_data))]
-        for i in range(max(nb_words)):
-            current_words = [
-                word_token[j][i] if i < nb_words[j] else [self.label_padding_value]
-                for j in range(len(batch_data))
-            ]
-            pad_word_token.append(
-                torch.tensor(
-                    pad_sequences_1D(
-                        current_words, padding_value=self.label_padding_value
-                    )
-                ).long()
-            )
-            for j in range(len(batch_data)):
-                if i >= nb_words[j]:
-                    word_len[j].append(0)
-        word_len = [i for i in zip(*word_len)]

         padding_mode = (
             self.config["padding_mode"] if "padding_mode" in self.config else "br"
         )
         imgs = [batch_data[i]["img"] for i in range(len(batch_data))]
-        imgs_shape = [batch_data[i]["img_shape"] for i in range(len(batch_data))]
-        imgs_reduced_shape = [
-            batch_data[i]["img_reduced_shape"] for i in range(len(batch_data))
-        ]
-        imgs_position = [batch_data[i]["img_position"] for i in range(len(batch_data))]
-        imgs_reduced_position = [
-            batch_data[i]["img_reduced_position"] for i in range(len(batch_data))
-        ]
         imgs = pad_images(
             imgs, padding_value=self.img_padding_value, padding_mode=padding_mode
         )
         imgs = torch.tensor(imgs).float().permute(0, 3, 1, 2)

         formatted_batch_data = {
-            "names": names,
-            "ids": ids,
-            "nb_lines": nb_lines,
-            "labels": labels,
-            "reverse_labels": reverse_labels,
-            "raw_labels": raw_labels,
-            "unchanged_labels": unchanged_labels,
-            "labels_len": labels_len,
-            "imgs": imgs,
-            "imgs_shape": imgs_shape,
-            "imgs_reduced_shape": imgs_reduced_shape,
-            "imgs_position": imgs_position,
-            "imgs_reduced_position": imgs_reduced_position,
-            "line_raw": line_raw,
-            "line_labels": pad_line_token,
-            "line_labels_len": line_len,
-            "nb_words": nb_words,
-            "word_raw": word_raw,
-            "word_labels": pad_word_token,
-            "word_labels_len": word_len,
-            "applied_da": applied_da,
+            formatted_key: [batch_data[i][initial_key] for i in range(len(batch_data))]
+            for formatted_key, initial_key in zip(
+                [
+                    "names",
+                    "labels_len",
+                    "raw_labels",
+                    "imgs_position",
+                    "imgs_reduced_shape",
+                ],
+                ["name", "label_len", "label", "img_position", "img_reduced_shape"],
+            )
         }
+        formatted_batch_data.update(
+            {
+                "imgs": imgs,
+                "labels": labels,
+            }
+        )

         return formatted_batch_data
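The rewritten collate keeps only the keys the training loop still consumes: the dict comprehension copies per-sample values straight into per-batch lists ("names" from "name", "labels_len" from "label_len", and so on), then the two padded tensors are merged in afterwards. A minimal usage sketch, assuming the constructor takes the config dict as shown above and that a ready OCRDataset instance called train_dataset exists (both assumptions, not verified defaults):

    from torch.utils.data import DataLoader

    collate = OCRCollateFunction(config={"padding_mode": "br"})
    loader = DataLoader(train_dataset, batch_size=8, shuffle=True, collate_fn=collate)

    batch = next(iter(loader))
    batch["imgs"]        # float tensor of shape (8, channels, max_height, max_width)
    batch["labels"]      # long tensor of token indices, padded per batch
    batch["raw_labels"]  # list of 8 ground-truth strings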