diff --git a/dan/ocr/manager/training.py b/dan/ocr/manager/training.py
index d436ead177c6f959fcc757c27722f970477df784..66f6c03bf217324e28dacde5d27c6bd91a444a42 100644
--- a/dan/ocr/manager/training.py
+++ b/dan/ocr/manager/training.py
@@ -21,6 +21,8 @@ from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
 from tqdm import tqdm
 
+from dan.ocr.decoder import GlobalHTADecoder
+from dan.ocr.encoder import FCN_Encoder
 from dan.ocr.manager.metrics import Inference, MetricManager
 from dan.ocr.manager.ocr import OCRDatasetManager
 from dan.ocr.mlflow import MLFLOW_AVAILABLE, logging_metrics, logging_tags_metrics
@@ -31,7 +33,9 @@ if MLFLOW_AVAILABLE:
     import mlflow
 
 logger = logging.getLogger(__name__)
-MODEL_NAMES = ("encoder", "decoder")
+MODEL_NAME_ENCODER = "encoder"
+MODEL_NAME_DECODER = "decoder"
+MODEL_NAMES = (MODEL_NAME_ENCODER, MODEL_NAME_DECODER)
 
 
 class GenericTrainingManager:
@@ -69,6 +73,14 @@ class GenericTrainingManager:
         self.init_paths()
         self.load_dataset()
 
+    @property
+    def encoder(self) -> FCN_Encoder | None:
+        return self.models.get(MODEL_NAME_ENCODER)
+
+    @property
+    def decoder(self) -> GlobalHTADecoder | None:
+        return self.models.get(MODEL_NAME_DECODER)
+
     def init_paths(self):
         """
         Create output folders for results and checkpoints
@@ -985,20 +997,18 @@ class Manager(GenericTrainingManager):
             hidden_predict = None
             cache = None
 
-            features = self.models["encoder"](batch_data["imgs"].to(self.device))
+            features = self.encoder(batch_data["imgs"].to(self.device))
             features_size = features.size()
 
             if self.device_params["use_ddp"]:
-                features = self.models[
-                    "decoder"
-                ].module.features_updater.get_pos_features(features)
-            else:
-                features = self.models["decoder"].features_updater.get_pos_features(
+                features = self.decoder.module.features_updater.get_pos_features(
                     features
                 )
+            else:
+                features = self.decoder.features_updater.get_pos_features(features)
             features = torch.flatten(features, start_dim=2, end_dim=3).permute(2, 0, 1)
 
-            output, pred, hidden_predict, cache, weights = self.models["decoder"](
+            output, pred, hidden_predict, cache, weights = self.decoder(
                 features,
                 simulated_y_pred[:, :-1],
                 [s[:2] for s in batch_data["imgs_reduced_shape"]],
@@ -1058,7 +1068,7 @@ class Manager(GenericTrainingManager):
             for i in range(b):
                 pos = batch_data["imgs_position"]
                 features_list.append(
-                    self.models["encoder"](
+                    self.encoder(
                         x[
                             i : i + 1,
                             :,
@@ -1079,21 +1089,19 @@ class Manager(GenericTrainingManager):
                         i, :, : features_list[i].size(2), : features_list[i].size(3)
                     ] = features_list[i]
             else:
-                features = self.models["encoder"](x)
+                features = self.encoder(x)
             features_size = features.size()
 
             if self.device_params["use_ddp"]:
-                features = self.models[
-                    "decoder"
-                ].module.features_updater.get_pos_features(features)
-            else:
-                features = self.models["decoder"].features_updater.get_pos_features(
+                features = self.decoder.module.features_updater.get_pos_features(
                     features
                 )
+            else:
+                features = self.decoder.features_updater.get_pos_features(features)
             features = torch.flatten(features, start_dim=2, end_dim=3).permute(2, 0, 1)
 
             for i in range(0, max_chars):
-                output, pred, hidden_predict, cache, weights = self.models["decoder"](
+                output, pred, hidden_predict, cache, weights = self.decoder(
                     features,
                     predicted_tokens,
                     [s[:2] for s in batch_data["imgs_reduced_shape"]],
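
Reviewer illustration (not part of the patch): the new `encoder`/`decoder` properties replace repeated `self.models["..."]` lookups with typed accessors backed by `dict.get`, so a missing sub-model yields `None` instead of a `KeyError`. A minimal, self-contained sketch of the same pattern, using hypothetical stand-in classes since `FCN_Encoder` and `GlobalHTADecoder` live elsewhere in the `dan` package:

    # Sketch only: Encoder/Decoder stand in for FCN_Encoder/GlobalHTADecoder.
    class Encoder: ...
    class Decoder: ...

    MODEL_NAME_ENCODER = "encoder"
    MODEL_NAME_DECODER = "decoder"

    class Manager:
        def __init__(self) -> None:
            # Sub-models are registered later, e.g. by a checkpoint-loading step.
            self.models: dict[str, Encoder | Decoder] = {}

        @property
        def encoder(self) -> Encoder | None:
            # dict.get returns None rather than raising KeyError
            # when the encoder has not been loaded yet.
            return self.models.get(MODEL_NAME_ENCODER)

        @property
        def decoder(self) -> Decoder | None:
            return self.models.get(MODEL_NAME_DECODER)

    manager = Manager()
    assert manager.encoder is None  # nothing loaded yet
    manager.models[MODEL_NAME_ENCODER] = Encoder()
    assert isinstance(manager.encoder, Encoder)  # property reflects the dict

One caveat the patch inherits from the original code: call sites such as `self.encoder(batch_data["imgs"].to(self.device))` assume the model is already loaded, so the `| None` in the return annotations is informational rather than checked.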