Commit 1dcb6c01 authored by Yoann Schneider, committed by Manon Blanco

Iterate over the dict directly

parent 6763a64b
1 merge request: !246 Iterate over the dict directly
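The motivation for the change: iterating over a Python dict yields its keys directly, and the `in` operator tests key membership by default, so wrapping the dict in `.keys()` is redundant in both cases. A minimal sketch of the equivalence (the dict below is illustrative, not taken from the code base):

    params = {"train": 0.8, "val": 0.1, "test": 0.1}

    # Iterating over the dict yields its keys directly...
    for set_name in params:
        print(set_name)

    # ...exactly like iterating over the explicit key view.
    assert list(params) == list(params.keys())

    # Membership tests also check keys by default.
    assert "train" in params
    assert ("val" in params) == ("val" in params.keys())

The shorter form avoids creating an intermediate dict_keys view and is the spelling recommended by common linters (for instance, the SIM118 check in flake8-simplify).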
@@ -95,7 +95,7 @@ class OCRDataset(Dataset):
                 gt_per_set = json.load(f)
             set_name = path_and_set["set_name"]
             gt = gt_per_set[set_name]
-            for filename in natural_sort(gt.keys()):
+            for filename in natural_sort(gt):
                 if isinstance(gt[filename], dict) and "text" in gt[filename]:
                     label = gt[filename]["text"]
                 else:
...
@@ -44,7 +44,7 @@ class MetricManager:
         for metric_name in self.metric_names:
             if metric_name in self.linked_metrics:
                 for linked_metric_name in self.linked_metrics[metric_name]:
-                    if linked_metric_name not in self.epoch_metrics.keys():
+                    if linked_metric_name not in self.epoch_metrics:
                         self.epoch_metrics[linked_metric_name] = list()
             else:
                 self.epoch_metrics[metric_name] = list()
@@ -53,7 +53,7 @@ class MetricManager:
         """
         Add batch metrics to the metrics
         """
-        for key in batch_metrics.keys():
+        for key in batch_metrics:
             if key in self.epoch_metrics:
                 self.epoch_metrics[key] += batch_metrics[key]
...
@@ -77,7 +77,7 @@ class OCRDatasetManager:
             self.mean, self.std = self.train_dataset.compute_std_mean()

-        for custom_name in self.params["val"].keys():
+        for custom_name in self.params["val"]:
             self.valid_datasets[custom_name] = OCRDataset(
                 set_name="val",
                 paths_and_sets=self.get_paths_and_sets(self.params["val"][custom_name]),
@@ -101,7 +101,7 @@ class OCRDatasetManager:
                 rank=self.device_params["ddp_rank"],
                 shuffle=True,
             )
-            for custom_name in self.valid_datasets.keys():
+            for custom_name in self.valid_datasets:
                 self.valid_samplers[custom_name] = DistributedSampler(
                     self.valid_datasets[custom_name],
                     num_replicas=self.device_params["nb_gpu"],
@@ -109,7 +109,7 @@ class OCRDatasetManager:
                     shuffle=False,
                 )
         else:
-            for custom_name in self.valid_datasets.keys():
+            for custom_name in self.valid_datasets:
                 self.valid_samplers[custom_name] = None

     def load_dataloaders(self):
@@ -130,7 +130,7 @@ class OCRDatasetManager:
             generator=self.generator,
         )

-        for key in self.valid_datasets.keys():
+        for key in self.valid_datasets:
             self.valid_loaders[key] = DataLoader(
                 self.valid_datasets[key],
                 batch_size=1,
@@ -158,7 +158,7 @@ class OCRDatasetManager:
         """
         Load test dataset, data sampler and data loader
         """
-        if custom_name in self.test_loaders.keys():
+        if custom_name in self.test_loaders:
             return
         paths_and_sets = list()
         for set_info in sets_list:
@@ -219,7 +219,7 @@ class OCRDatasetManager:
             return self.params["charset"]
         datasets = self.params["datasets"]
         charset = set()
-        for key in datasets.keys():
+        for key in datasets:
             with open(os.path.join(datasets[key], "charset.pkl"), "rb") as f:
                 charset = charset.union(set(pickle.load(f)))
         if "" in charset:
...
@@ -220,7 +220,7 @@ class GenericTrainingManager:
             if "scaler_state_dict" in checkpoint:
                 self.scaler.load_state_dict(checkpoint["scaler_state_dict"])
             # Load model weights from past training
-            for model_name in self.models.keys():
+            for model_name in self.models:
                 # Transform to DDP/from DDP model
                 checkpoint[f"{model_name}_state_dict"] = fix_ddp_layers_names(
                     checkpoint[f"{model_name}_state_dict"],
@@ -236,7 +236,7 @@ class GenericTrainingManager:
         Initialize model
         """
         # Specific weights initialization if exists
-        for model_name in self.models.keys():
+        for model_name in self.models:
             try:
                 self.models[model_name].init_weights()
             except Exception:
@@ -338,7 +338,7 @@ class GenericTrainingManager:
         """
         Load the optimizer of each model
         """
-        for model_name in self.models.keys():
+        for model_name in self.models:
             if (
                 checkpoint
                 and "optimizer_named_params_{}".format(model_name) in checkpoint
@@ -381,8 +381,7 @@ class GenericTrainingManager:
             if (
                 "lr_schedulers" in self.params["training"]
                 and self.params["training"]["lr_schedulers"]
-                and "lr_scheduler_{}_state_dict".format(model_name)
-                in checkpoint.keys()
+                and "lr_scheduler_{}_state_dict".format(model_name) in checkpoint
             ):
                 self.lr_schedulers[model_name].load_state_dict(
                     checkpoint["lr_scheduler_{}_state_dict".format(model_name)]
@@ -421,7 +420,7 @@ class GenericTrainingManager:
                     "lr_scheduler_{}_state_dict".format(model_name)
                 ] = self.lr_schedulers[model_name].state_dict()
             content = self.add_save_info(content)
-            for model_name in self.models.keys():
+            for model_name in self.models:
                 content["{}_state_dict".format(model_name)] = self.models[
                     model_name
                 ].state_dict()
@@ -575,7 +574,7 @@ class GenericTrainingManager:
         # perform epochs
         for num_epoch in range(self.latest_epoch + 1, nb_epochs):
             # set models trainable
-            for model_name in self.models.keys():
+            for model_name in self.models:
                 self.models[model_name].train()
             self.latest_epoch = num_epoch
             if self.dataset.train_dataset.curriculum_config:
@@ -642,7 +641,7 @@ class GenericTrainingManager:
             if self.is_master:
                 # log metrics in tensorboard file
-                for key in display_values.keys():
+                for key in display_values:
                     self.writer.add_scalar(
                         "{}_{}".format(self.params["dataset"]["train"]["name"], key),
                         display_values[key],
@@ -656,14 +655,14 @@ class GenericTrainingManager:
                 % self.params["training"]["validation"]["eval_on_valid_interval"]
                 == 0
             ):
-                for valid_set_name in self.dataset.valid_loaders.keys():
+                for valid_set_name in self.dataset.valid_loaders:
                     # evaluate set and compute metrics
                     eval_values = self.evaluate(
                         valid_set_name, mlflow_logging=mlflow_logging
                     )
                     # log valid metrics in tensorboard file
                     if self.is_master:
-                        for key in eval_values.keys():
+                        for key in eval_values:
                             self.writer.add_scalar(
                                 "{}_{}".format(valid_set_name, key),
                                 eval_values[key],
@@ -696,7 +695,7 @@ class GenericTrainingManager:
         """
         loader = self.dataset.valid_loaders[set_name]
         # Set models in eval mode
-        for model_name in self.models.keys():
+        for model_name in self.models:
             self.models[model_name].eval()
         metric_names = self.params["training"]["metrics"]["eval"]
         display_values = None
@@ -749,7 +748,7 @@ class GenericTrainingManager:
         self.dataset.generate_test_loader(custom_name, sets_list)
         loader = self.dataset.test_loaders[custom_name]
         # Set models in eval mode
-        for model_name in self.models.keys():
+        for model_name in self.models:
             self.models[model_name].eval()
         # initialize epoch metrics
@@ -835,7 +834,7 @@ class GenericTrainingManager:
         """
         Merge metrics when Distributed Data Parallel is used
         """
-        for metric_name in metrics.keys():
+        for metric_name in metrics:
             if metric_name in [
                 "edit_words",
                 "nb_words",
@@ -891,7 +890,7 @@ class GenericTrainingManager:
         """
         Load curriculum info from saved model info
         """
-        if "curriculum_config" in info_dict.keys():
+        if "curriculum_config" in info_dict:
             self.dataset.train_dataset.curriculum_config = info_dict[
                 "curriculum_config"
             ]
@@ -915,7 +914,7 @@ class Manager(OCRManager):
         super(Manager, self).__init__(params)

     def load_save_info(self, info_dict):
-        if "curriculum_config" in info_dict.keys():
+        if "curriculum_config" in info_dict:
             if self.dataset.train_dataset is not None:
                 self.dataset.train_dataset.curriculum_config = info_dict[
                     "curriculum_config"
...
@@ -18,7 +18,7 @@ class DropoutScheduler:
         self.step_num += num

     def init_teta_list(self, models):
-        for model_name in models.keys():
+        for model_name in models:
             self.init_teta_list_module(models[model_name])

     def init_teta_list_module(self, module):
...
@@ -48,7 +48,7 @@ def train_and_test(rank, params, mlflow_logging=False):
     model.load_model()

     metrics = ["cer", "wer", "wer_no_punct", "time"]
-    for dataset_name in params["dataset"]["datasets"].keys():
+    for dataset_name in params["dataset"]["datasets"]:
         for set_name in ["test", "val", "train"]:
             model.predict(
                 "{}-{}".format(dataset_name, set_name),
...
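One general Python caveat, not exercised by this diff: direct iteration is only safe because none of these loops add or remove keys from the dict they walk. Mutating a dict's size during iteration raises a RuntimeError (with `.keys()` as well, since it is a live view of the dict); when mutation is needed, iterate over a snapshot instead. A hypothetical sketch using an illustrative dict:

    epoch_metrics = {"cer": [0.1], "wer": []}

    # list(...) takes a snapshot of the keys, so deleting entries is safe.
    for name in list(epoch_metrics):
        if not epoch_metrics[name]:
            del epoch_metrics[name]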