diff --git a/dan/ocr/manager/training.py b/dan/ocr/manager/training.py
index d436ead177c6f959fcc757c27722f970477df784..083a40f4fc6a2675fb2033a342190973056720cd 100644
--- a/dan/ocr/manager/training.py
+++ b/dan/ocr/manager/training.py
@@ -31,7 +31,9 @@ if MLFLOW_AVAILABLE:
     import mlflow
 
 logger = logging.getLogger(__name__)
-MODEL_NAMES = ("encoder", "decoder")
+MODEL_NAME_ENCODER = "encoder"
+MODEL_NAME_DECODER = "decoder"
+MODEL_NAMES = (MODEL_NAME_ENCODER, MODEL_NAME_DECODER)
 
 
 class GenericTrainingManager:
@@ -69,6 +71,14 @@ class GenericTrainingManager:
         self.init_paths()
         self.load_dataset()
 
+    @property
+    def encoder(self):
+        return self.models.get(MODEL_NAME_ENCODER)
+
+    @property
+    def decoder(self):
+        return self.models.get(MODEL_NAME_DECODER)
+
     def init_paths(self):
         """
         Create output folders for results and checkpoints
@@ -985,20 +995,18 @@ class Manager(GenericTrainingManager):
         hidden_predict = None
         cache = None
 
-        features = self.models["encoder"](batch_data["imgs"].to(self.device))
+        features = self.encoder(batch_data["imgs"].to(self.device))
         features_size = features.size()
 
         if self.device_params["use_ddp"]:
-            features = self.models[
-                "decoder"
-            ].module.features_updater.get_pos_features(features)
-        else:
-            features = self.models["decoder"].features_updater.get_pos_features(
+            features = self.decoder.module.features_updater.get_pos_features(
                 features
             )
+        else:
+            features = self.decoder.features_updater.get_pos_features(features)
         features = torch.flatten(features, start_dim=2, end_dim=3).permute(2, 0, 1)
 
-        output, pred, hidden_predict, cache, weights = self.models["decoder"](
+        output, pred, hidden_predict, cache, weights = self.decoder(
             features,
             simulated_y_pred[:, :-1],
             [s[:2] for s in batch_data["imgs_reduced_shape"]],
@@ -1058,7 +1066,7 @@ class Manager(GenericTrainingManager):
             for i in range(b):
                 pos = batch_data["imgs_position"]
                 features_list.append(
-                    self.models["encoder"](
+                    self.encoder(
                         x[
                             i : i + 1,
                             :,
@@ -1079,21 +1087,19 @@ class Manager(GenericTrainingManager):
                     i, :, : features_list[i].size(2), : features_list[i].size(3)
                 ] = features_list[i]
         else:
-            features = self.models["encoder"](x)
+            features = self.encoder(x)
         features_size = features.size()
 
         if self.device_params["use_ddp"]:
-            features = self.models[
-                "decoder"
-            ].module.features_updater.get_pos_features(features)
-        else:
-            features = self.models["decoder"].features_updater.get_pos_features(
+            features = self.decoder.module.features_updater.get_pos_features(
                 features
             )
+        else:
+            features = self.decoder.features_updater.get_pos_features(features)
         features = torch.flatten(features, start_dim=2, end_dim=3).permute(2, 0, 1)
 
         for i in range(0, max_chars):
-            output, pred, hidden_predict, cache, weights = self.models["decoder"](
+            output, pred, hidden_predict, cache, weights = self.decoder(
                 features,
                 predicted_tokens,
                 [s[:2] for s in batch_data["imgs_reduced_shape"]],