Skip to content
Snippets Groups Projects

Create properties to access encoder/decoder

Merged Manon Blanco requested to merge train-model-properties into main
All threads resolved!
1 file changed: +22 −16
Compare changes
  • Side-by-side
  • Inline
+22 −16
@@ -31,7 +31,9 @@ if MLFLOW_AVAILABLE:
import mlflow
logger = logging.getLogger(__name__)
# Keys under which the two sub-models are stored in self.models.
# (The tuple at the top of this span was a dead store — it was immediately
# rebuilt from the two named constants below, so it has been removed.)
MODEL_NAME_ENCODER = "encoder"
MODEL_NAME_DECODER = "decoder"
MODEL_NAMES = (MODEL_NAME_ENCODER, MODEL_NAME_DECODER)
class GenericTrainingManager:
@@ -69,6 +71,14 @@ class GenericTrainingManager:
self.init_paths()
self.load_dataset()
@property
def encoder(self):
    """The encoder sub-model registered in ``self.models``, or ``None`` if absent."""
    return self.models.get(MODEL_NAME_ENCODER, None)
@property
def decoder(self):
    """The decoder sub-model registered in ``self.models``, or ``None`` if absent."""
    return self.models.get(MODEL_NAME_DECODER, None)
def init_paths(self):
"""
Create output folders for results and checkpoints
@@ -985,20 +995,18 @@ class Manager(GenericTrainingManager):
hidden_predict = None
cache = None
features = self.models["encoder"](batch_data["imgs"].to(self.device))
features = self.encoder(batch_data["imgs"].to(self.device))
features_size = features.size()
if self.device_params["use_ddp"]:
features = self.models[
"decoder"
].module.features_updater.get_pos_features(features)
else:
features = self.models["decoder"].features_updater.get_pos_features(
features = self.decoder.module.features_updater.get_pos_features(
features
)
else:
features = self.decoder.features_updater.get_pos_features(features)
features = torch.flatten(features, start_dim=2, end_dim=3).permute(2, 0, 1)
output, pred, hidden_predict, cache, weights = self.models["decoder"](
output, pred, hidden_predict, cache, weights = self.decoder(
features,
simulated_y_pred[:, :-1],
[s[:2] for s in batch_data["imgs_reduced_shape"]],
@@ -1058,7 +1066,7 @@ class Manager(GenericTrainingManager):
for i in range(b):
pos = batch_data["imgs_position"]
features_list.append(
self.models["encoder"](
self.encoder(
x[
i : i + 1,
:,
@@ -1079,21 +1087,19 @@ class Manager(GenericTrainingManager):
i, :, : features_list[i].size(2), : features_list[i].size(3)
] = features_list[i]
else:
features = self.models["encoder"](x)
features = self.encoder(x)
features_size = features.size()
if self.device_params["use_ddp"]:
features = self.models[
"decoder"
].module.features_updater.get_pos_features(features)
else:
features = self.models["decoder"].features_updater.get_pos_features(
features = self.decoder.module.features_updater.get_pos_features(
features
)
else:
features = self.decoder.features_updater.get_pos_features(features)
features = torch.flatten(features, start_dim=2, end_dim=3).permute(2, 0, 1)
for i in range(0, max_chars):
output, pred, hidden_predict, cache, weights = self.models["decoder"](
output, pred, hidden_predict, cache, weights = self.decoder(
features,
predicted_tokens,
[s[:2] for s in batch_data["imgs_reduced_shape"]],
Loading