From be8b1a667f84c1161f57826ab15b72046cf24a2b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?M=C3=A9lodie=20Boillet?= <boillet@teklia.com>
Date: Wed, 24 May 2023 15:08:25 +0200
Subject: [PATCH] Remove ids + clean code

---
 dan/manager/metrics.py  |  1 -
 dan/manager/ocr.py      | 76 ++++++++++++-----------------------------
 dan/manager/training.py |  4 ----
 3 files changed, 22 insertions(+), 59 deletions(-)

diff --git a/dan/manager/metrics.py b/dan/manager/metrics.py
index abb8c166..83aaa78c 100644
--- a/dan/manager/metrics.py
+++ b/dan/manager/metrics.py
@@ -55,7 +55,6 @@ class MetricManager:
         self.epoch_metrics = {
             "nb_samples": list(),
             "names": list(),
-            "ids": list(),
         }
 
         for metric_name in self.metric_names:
diff --git a/dan/manager/ocr.py b/dan/manager/ocr.py
index a3c44c82..2132ff62 100644
--- a/dan/manager/ocr.py
+++ b/dan/manager/ocr.py
@@ -146,9 +146,8 @@ class OCRDataset(GenericDataset):
         if "normalize" in self.params["config"] and self.params["config"]["normalize"]:
             sample["img"] = (sample["img"] - self.mean) / self.std
 
-        sample["img_shape"] = sample["img"].shape
         sample["img_reduced_shape"] = np.ceil(
-            sample["img_shape"] / self.reduce_dims_factor
+            sample["img"].shape / self.reduce_dims_factor
         ).astype(int)
 
         if self.set_name == "train":
@@ -157,8 +156,8 @@
             ]
 
         sample["img_position"] = [
-            [0, sample["img_shape"][0]],
-            [0, sample["img_shape"][1]],
+            [0, sample["img"].shape[0]],
+            [0, sample["img"].shape[1]],
         ]
         # Padding constraints to handle model needs
         if "padding" in self.params["config"] and self.params["config"]["padding"]:
@@ -189,10 +188,6 @@
                 padding_mode=self.params["config"]["padding"]["mode"],
                 return_position=True,
             )
-        sample["img_reduced_position"] = [
-            np.ceil(p / factor).astype(int)
-            for p, factor in zip(sample["img_position"], self.reduce_dims_factor[:2])
-        ]
         return sample
 
     def convert_labels(self):
@@ -461,67 +456,40 @@ class OCRCollateFunction:
         self.config = config
 
     def __call__(self, batch_data):
-        names = [batch_data[i]["name"] for i in range(len(batch_data))]
-        ids = [
-            batch_data[i]["name"].split("/")[-1].split(".")[0]
-            for i in range(len(batch_data))
-        ]
-
         labels = [batch_data[i]["token_label"] for i in range(len(batch_data))]
         labels = pad_sequences_1D(labels, padding_value=self.label_padding_value)
         labels = torch.tensor(labels).long()
-        reverse_labels = [
-            [
-                batch_data[i]["token_label"][0],
-            ]
-            + batch_data[i]["token_label"][-2:0:-1]
-            + [
-                batch_data[i]["token_label"][-1],
-            ]
-            for i in range(len(batch_data))
-        ]
-        reverse_labels = pad_sequences_1D(
-            reverse_labels, padding_value=self.label_padding_value
-        )
-        reverse_labels = torch.tensor(reverse_labels).long()
-        labels_len = [batch_data[i]["label_len"] for i in range(len(batch_data))]
-
-        raw_labels = [batch_data[i]["label"] for i in range(len(batch_data))]
-        unchanged_labels = [
-            batch_data[i]["unchanged_label"] for i in range(len(batch_data))
-        ]
 
         padding_mode = (
             self.config["padding_mode"] if "padding_mode" in self.config else "br"
         )
         imgs = [batch_data[i]["img"] for i in range(len(batch_data))]
-        imgs_shape = [batch_data[i]["img_shape"] for i in range(len(batch_data))]
-        imgs_reduced_shape = [
-            batch_data[i]["img_reduced_shape"] for i in range(len(batch_data))
-        ]
-        imgs_position = [batch_data[i]["img_position"] for i in range(len(batch_data))]
-        imgs_reduced_position = [
-            batch_data[i]["img_reduced_position"] for i in range(len(batch_data))
-        ]
         imgs = pad_images(
             imgs, padding_value=self.img_padding_value, padding_mode=padding_mode
         )
         imgs = torch.tensor(imgs).float().permute(0, 3, 1, 2)
+
         formatted_batch_data = {
-            "names": names,
-            "ids": ids,
-            "labels": labels,
-            "reverse_labels": reverse_labels,
-            "raw_labels": raw_labels,
-            "unchanged_labels": unchanged_labels,
-            "labels_len": labels_len,
-            "imgs": imgs,
-            "imgs_shape": imgs_shape,
-            "imgs_reduced_shape": imgs_reduced_shape,
-            "imgs_position": imgs_position,
-            "imgs_reduced_position": imgs_reduced_position,
+            formatted_key: [batch_data[i][initial_key] for i in range(len(batch_data))]
+            for formatted_key, initial_key in zip(
+                [
+                    "names",
+                    "labels_len",
+                    "raw_labels",
+                    "imgs_position",
+                    "imgs_reduced_shape",
+                ],
+                ["name", "label_len", "label", "img_position", "img_reduced_shape"],
+            )
         }
+        formatted_batch_data.update(
+            {
+                "imgs": imgs,
+                "labels": labels,
+            }
+        )
+
         return formatted_batch_data
diff --git a/dan/manager/training.py b/dan/manager/training.py
index 53012a7c..25855e1b 100644
--- a/dan/manager/training.py
+++ b/dan/manager/training.py
@@ -617,7 +617,6 @@ class GenericTrainingManager:
                     batch_values, metric_names
                 )
                 batch_metrics["names"] = batch_data["names"]
-                batch_metrics["ids"] = batch_data["ids"]
                 # Merge metrics if Distributed Data Parallel is used
                 if self.params["training_params"]["use_ddp"]:
                     batch_metrics = self.merge_ddp_metrics(batch_metrics)
@@ -762,7 +761,6 @@
                     batch_values, metric_names
                 )
                 batch_metrics["names"] = batch_data["names"]
-                batch_metrics["ids"] = batch_data["ids"]
                 # merge metrics values if Distributed Data Parallel is used
                 if self.params["training_params"]["use_ddp"]:
                     batch_metrics = self.merge_ddp_metrics(batch_metrics)
@@ -815,7 +813,6 @@
                     batch_values, metric_names
                 )
                 batch_metrics["names"] = batch_data["names"]
-                batch_metrics["ids"] = batch_data["ids"]
                 # merge batch metrics if Distributed Data Parallel is used
                 if self.params["training_params"]["use_ddp"]:
                     batch_metrics = self.merge_ddp_metrics(batch_metrics)
@@ -890,7 +887,6 @@
                 "edit_chars_force_len",
                 "edit_chars_curr",
                 "nb_chars_curr",
-                "ids",
             ]:
                 metrics[metric_name] = self.cat_ddp_metric(metrics[metric_name])
             elif metric_name in [
-- 
GitLab