diff --git a/dan/manager/ocr.py b/dan/manager/ocr.py index 42bfcadccce37e4e063306fcf2f78b5e81463559..046948067c488334cfc73599bc2ef2159963f019 100644 --- a/dan/manager/ocr.py +++ b/dan/manager/ocr.py @@ -242,7 +242,6 @@ class OCRDataset(GenericDataset): def convert_sample_labels(self, sample): label = sample["label"] - line_labels = label.split("\n") if "remove_linebreaks" in self.params["config"]["constraints"]: full_label = label.replace("\n", " ").replace(" ", " ") else: @@ -255,13 +254,6 @@ class OCRDataset(GenericDataset): sample["label_len"] = len(sample["token_label"]) if "add_sot" in self.params["config"]["constraints"]: sample["token_label"].insert(0, self.tokens["start"]) - - sample["line_label"] = line_labels - sample["token_line_label"] = [ - LM_str_to_ind(self.charset, label) for label in line_labels - ] - sample["line_label_len"] = [len(label) for label in line_labels] - sample["nb_lines"] = len(line_labels) return sample def generate_synthetic_data(self, sample): @@ -537,28 +529,6 @@ class OCRCollateFunction: batch_data[i]["unchanged_label"] for i in range(len(batch_data)) ] - nb_lines = [batch_data[i]["nb_lines"] for i in range(len(batch_data))] - line_raw = [batch_data[i]["line_label"] for i in range(len(batch_data))] - line_token = [batch_data[i]["token_line_label"] for i in range(len(batch_data))] - pad_line_token = list() - line_len = [batch_data[i]["line_label_len"] for i in range(len(batch_data))] - for i in range(max(nb_lines)): - current_lines = [ - line_token[j][i] if i < nb_lines[j] else [self.label_padding_value] - for j in range(len(batch_data)) - ] - pad_line_token.append( - torch.tensor( - pad_sequences_1D( - current_lines, padding_value=self.label_padding_value - ) - ).long() - ) - for j in range(len(batch_data)): - if i >= nb_lines[j]: - line_len[j].append(0) - line_len = [i for i in zip(*line_len)] - padding_mode = ( self.config["padding_mode"] if "padding_mode" in self.config else "br" ) @@ -578,7 +548,6 @@ class OCRCollateFunction: formatted_batch_data = { "names": names, "ids": ids, - "nb_lines": nb_lines, "labels": labels, "reverse_labels": reverse_labels, "raw_labels": raw_labels, @@ -589,9 +558,6 @@ class OCRCollateFunction: "imgs_reduced_shape": imgs_reduced_shape, "imgs_position": imgs_position, "imgs_reduced_position": imgs_reduced_position, - "line_raw": line_raw, - "line_labels": pad_line_token, - "line_labels_len": line_len, } return formatted_batch_data