diff --git a/dan/ocr/manager/training.py b/dan/ocr/manager/training.py
index d436ead177c6f959fcc757c27722f970477df784..66f6c03bf217324e28dacde5d27c6bd91a444a42 100644
--- a/dan/ocr/manager/training.py
+++ b/dan/ocr/manager/training.py
@@ -21,6 +21,8 @@ from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
 from tqdm import tqdm
 
+from dan.ocr.decoder import GlobalHTADecoder
+from dan.ocr.encoder import FCN_Encoder
 from dan.ocr.manager.metrics import Inference, MetricManager
 from dan.ocr.manager.ocr import OCRDatasetManager
 from dan.ocr.mlflow import MLFLOW_AVAILABLE, logging_metrics, logging_tags_metrics
@@ -31,7 +33,9 @@ if MLFLOW_AVAILABLE:
     import mlflow
 
 logger = logging.getLogger(__name__)
-MODEL_NAMES = ("encoder", "decoder")
+MODEL_NAME_ENCODER = "encoder"
+MODEL_NAME_DECODER = "decoder"
+MODEL_NAMES = (MODEL_NAME_ENCODER, MODEL_NAME_DECODER)
 
 
 class GenericTrainingManager:
@@ -69,6 +73,14 @@ class GenericTrainingManager:
         self.init_paths()
         self.load_dataset()
 
+    @property
+    def encoder(self) -> FCN_Encoder | None:
+        return self.models.get(MODEL_NAME_ENCODER)
+
+    @property
+    def decoder(self) -> GlobalHTADecoder | None:
+        return self.models.get(MODEL_NAME_DECODER)
+
     def init_paths(self):
         """
         Create output folders for results and checkpoints
@@ -985,20 +997,18 @@ class Manager(GenericTrainingManager):
             hidden_predict = None
             cache = None
 
-            features = self.models["encoder"](batch_data["imgs"].to(self.device))
+            features = self.encoder(batch_data["imgs"].to(self.device))
             features_size = features.size()
 
             if self.device_params["use_ddp"]:
-                features = self.models[
-                    "decoder"
-                ].module.features_updater.get_pos_features(features)
-            else:
-                features = self.models["decoder"].features_updater.get_pos_features(
+                features = self.decoder.module.features_updater.get_pos_features(
                     features
                 )
+            else:
+                features = self.decoder.features_updater.get_pos_features(features)
             features = torch.flatten(features, start_dim=2, end_dim=3).permute(2, 0, 1)
 
-            output, pred, hidden_predict, cache, weights = self.models["decoder"](
+            output, pred, hidden_predict, cache, weights = self.decoder(
                 features,
                 simulated_y_pred[:, :-1],
                 [s[:2] for s in batch_data["imgs_reduced_shape"]],
@@ -1058,7 +1068,7 @@ class Manager(GenericTrainingManager):
                 for i in range(b):
                     pos = batch_data["imgs_position"]
                     features_list.append(
-                        self.models["encoder"](
+                        self.encoder(
                             x[
                                 i : i + 1,
                                 :,
@@ -1079,21 +1089,19 @@ class Manager(GenericTrainingManager):
                         i, :, : features_list[i].size(2), : features_list[i].size(3)
                     ] = features_list[i]
             else:
-                features = self.models["encoder"](x)
+                features = self.encoder(x)
             features_size = features.size()
 
             if self.device_params["use_ddp"]:
-                features = self.models[
-                    "decoder"
-                ].module.features_updater.get_pos_features(features)
-            else:
-                features = self.models["decoder"].features_updater.get_pos_features(
+                features = self.decoder.module.features_updater.get_pos_features(
                     features
                 )
+            else:
+                features = self.decoder.features_updater.get_pos_features(features)
             features = torch.flatten(features, start_dim=2, end_dim=3).permute(2, 0, 1)
 
             for i in range(0, max_chars):
-                output, pred, hidden_predict, cache, weights = self.models["decoder"](
+                output, pred, hidden_predict, cache, weights = self.decoder(
                     features,
                     predicted_tokens,
                     [s[:2] for s in batch_data["imgs_reduced_shape"]],