diff --git a/dan/ocr/manager/dataset.py b/dan/ocr/manager/dataset.py index d84b4372235bbfa35eee724835aed4a38906e3b2..c07d6f42829b621752d6766e20c2539bd268486c 100644 --- a/dan/ocr/manager/dataset.py +++ b/dan/ocr/manager/dataset.py @@ -95,7 +95,7 @@ class OCRDataset(Dataset): gt_per_set = json.load(f) set_name = path_and_set["set_name"] gt = gt_per_set[set_name] - for filename in natural_sort(gt.keys()): + for filename in natural_sort(gt): if isinstance(gt[filename], dict) and "text" in gt[filename]: label = gt[filename]["text"] else: diff --git a/dan/ocr/manager/metrics.py b/dan/ocr/manager/metrics.py index f0c266752bc940247792b1d7f232daafa41e98fd..d941c21436ab557026bc2b2510f1450ee0139dbd 100644 --- a/dan/ocr/manager/metrics.py +++ b/dan/ocr/manager/metrics.py @@ -44,7 +44,7 @@ class MetricManager: for metric_name in self.metric_names: if metric_name in self.linked_metrics: for linked_metric_name in self.linked_metrics[metric_name]: - if linked_metric_name not in self.epoch_metrics.keys(): + if linked_metric_name not in self.epoch_metrics: self.epoch_metrics[linked_metric_name] = list() else: self.epoch_metrics[metric_name] = list() @@ -53,7 +53,7 @@ class MetricManager: """ Add batch metrics to the metrics """ - for key in batch_metrics.keys(): + for key in batch_metrics: if key in self.epoch_metrics: self.epoch_metrics[key] += batch_metrics[key] diff --git a/dan/ocr/manager/ocr.py b/dan/ocr/manager/ocr.py index 31079a0449a191173d15dff580a3ac274afd8f4b..20e34a1593e60699127537a90798eec6722bf06a 100644 --- a/dan/ocr/manager/ocr.py +++ b/dan/ocr/manager/ocr.py @@ -77,7 +77,7 @@ class OCRDatasetManager: self.mean, self.std = self.train_dataset.compute_std_mean() - for custom_name in self.params["val"].keys(): + for custom_name in self.params["val"]: self.valid_datasets[custom_name] = OCRDataset( set_name="val", paths_and_sets=self.get_paths_and_sets(self.params["val"][custom_name]), @@ -101,7 +101,7 @@ class OCRDatasetManager: rank=self.device_params["ddp_rank"], shuffle=True, ) - for custom_name in self.valid_datasets.keys(): + for custom_name in self.valid_datasets: self.valid_samplers[custom_name] = DistributedSampler( self.valid_datasets[custom_name], num_replicas=self.device_params["nb_gpu"], @@ -109,7 +109,7 @@ class OCRDatasetManager: shuffle=False, ) else: - for custom_name in self.valid_datasets.keys(): + for custom_name in self.valid_datasets: self.valid_samplers[custom_name] = None def load_dataloaders(self): @@ -130,7 +130,7 @@ class OCRDatasetManager: generator=self.generator, ) - for key in self.valid_datasets.keys(): + for key in self.valid_datasets: self.valid_loaders[key] = DataLoader( self.valid_datasets[key], batch_size=1, @@ -158,7 +158,7 @@ class OCRDatasetManager: """ Load test dataset, data sampler and data loader """ - if custom_name in self.test_loaders.keys(): + if custom_name in self.test_loaders: return paths_and_sets = list() for set_info in sets_list: @@ -219,7 +219,7 @@ class OCRDatasetManager: return self.params["charset"] datasets = self.params["datasets"] charset = set() - for key in datasets.keys(): + for key in datasets: with open(os.path.join(datasets[key], "charset.pkl"), "rb") as f: charset = charset.union(set(pickle.load(f))) if "" in charset: diff --git a/dan/ocr/manager/training.py b/dan/ocr/manager/training.py index e14dab8abbf668efb550eb674d2580fe9221f2f7..edb906aadd053f209da1d15888afddf3299ff192 100644 --- a/dan/ocr/manager/training.py +++ b/dan/ocr/manager/training.py @@ -220,7 +220,7 @@ class GenericTrainingManager: if "scaler_state_dict" in checkpoint: self.scaler.load_state_dict(checkpoint["scaler_state_dict"]) # Load model weights from past training - for model_name in self.models.keys(): + for model_name in self.models: # Transform to DDP/from DDP model checkpoint[f"{model_name}_state_dict"] = fix_ddp_layers_names( checkpoint[f"{model_name}_state_dict"], @@ -236,7 +236,7 @@ class GenericTrainingManager: Initialize model """ # Specific weights initialization if exists - for model_name in self.models.keys(): + for model_name in self.models: try: self.models[model_name].init_weights() except Exception: @@ -338,7 +338,7 @@ class GenericTrainingManager: """ Load the optimizer of each model """ - for model_name in self.models.keys(): + for model_name in self.models: if ( checkpoint and "optimizer_named_params_{}".format(model_name) in checkpoint @@ -381,8 +381,7 @@ class GenericTrainingManager: if ( "lr_schedulers" in self.params["training"] and self.params["training"]["lr_schedulers"] - and "lr_scheduler_{}_state_dict".format(model_name) - in checkpoint.keys() + and "lr_scheduler_{}_state_dict".format(model_name) in checkpoint ): self.lr_schedulers[model_name].load_state_dict( checkpoint["lr_scheduler_{}_state_dict".format(model_name)] @@ -421,7 +420,7 @@ class GenericTrainingManager: "lr_scheduler_{}_state_dict".format(model_name) ] = self.lr_schedulers[model_name].state_dict() content = self.add_save_info(content) - for model_name in self.models.keys(): + for model_name in self.models: content["{}_state_dict".format(model_name)] = self.models[ model_name ].state_dict() @@ -575,7 +574,7 @@ class GenericTrainingManager: # perform epochs for num_epoch in range(self.latest_epoch + 1, nb_epochs): # set models trainable - for model_name in self.models.keys(): + for model_name in self.models: self.models[model_name].train() self.latest_epoch = num_epoch if self.dataset.train_dataset.curriculum_config: @@ -642,7 +641,7 @@ class GenericTrainingManager: if self.is_master: # log metrics in tensorboard file - for key in display_values.keys(): + for key in display_values: self.writer.add_scalar( "{}_{}".format(self.params["dataset"]["train"]["name"], key), display_values[key], @@ -656,14 +655,14 @@ class GenericTrainingManager: % self.params["training"]["validation"]["eval_on_valid_interval"] == 0 ): - for valid_set_name in self.dataset.valid_loaders.keys(): + for valid_set_name in self.dataset.valid_loaders: # evaluate set and compute metrics eval_values = self.evaluate( valid_set_name, mlflow_logging=mlflow_logging ) # log valid metrics in tensorboard file if self.is_master: - for key in eval_values.keys(): + for key in eval_values: self.writer.add_scalar( "{}_{}".format(valid_set_name, key), eval_values[key], @@ -696,7 +695,7 @@ class GenericTrainingManager: """ loader = self.dataset.valid_loaders[set_name] # Set models in eval mode - for model_name in self.models.keys(): + for model_name in self.models: self.models[model_name].eval() metric_names = self.params["training"]["metrics"]["eval"] display_values = None @@ -749,7 +748,7 @@ class GenericTrainingManager: self.dataset.generate_test_loader(custom_name, sets_list) loader = self.dataset.test_loaders[custom_name] # Set models in eval mode - for model_name in self.models.keys(): + for model_name in self.models: self.models[model_name].eval() # initialize epoch metrics @@ -835,7 +834,7 @@ class GenericTrainingManager: """ Merge metrics when Distributed Data Parallel is used """ - for metric_name in metrics.keys(): + for metric_name in metrics: if metric_name in [ "edit_words", "nb_words", @@ -891,7 +890,7 @@ class GenericTrainingManager: """ Load curriculum info from saved model info """ - if "curriculum_config" in info_dict.keys(): + if "curriculum_config" in info_dict: self.dataset.train_dataset.curriculum_config = info_dict[ "curriculum_config" ] @@ -915,7 +914,7 @@ class Manager(OCRManager): super(Manager, self).__init__(params) def load_save_info(self, info_dict): - if "curriculum_config" in info_dict.keys(): + if "curriculum_config" in info_dict: if self.dataset.train_dataset is not None: self.dataset.train_dataset.curriculum_config = info_dict[ "curriculum_config" diff --git a/dan/ocr/schedulers.py b/dan/ocr/schedulers.py index 3c6ef2389766f646647887bc8692abbf2012467c..bdabe62980244dfadf11df22147fa239d9fa5d53 100644 --- a/dan/ocr/schedulers.py +++ b/dan/ocr/schedulers.py @@ -18,7 +18,7 @@ class DropoutScheduler: self.step_num += num def init_teta_list(self, models): - for model_name in models.keys(): + for model_name in models: self.init_teta_list_module(models[model_name]) def init_teta_list_module(self, module): diff --git a/dan/ocr/train.py b/dan/ocr/train.py index 1469f19b3bcca8323b35ae494751537b9453b725..001b4cd27a34de3bc60cbce4ee5db9e1b547fac1 100644 --- a/dan/ocr/train.py +++ b/dan/ocr/train.py @@ -48,7 +48,7 @@ def train_and_test(rank, params, mlflow_logging=False): model.load_model() metrics = ["cer", "wer", "wer_no_punct", "time"] - for dataset_name in params["dataset"]["datasets"].keys(): + for dataset_name in params["dataset"]["datasets"]: for set_name in ["test", "val", "train"]: model.predict( "{}-{}".format(dataset_name, set_name),