Skip to content
Snippets Groups Projects
Commit 389c3505 authored by Yoann Schneider's avatar Yoann Schneider :tennis:
Browse files

Merge branch 'train-model-properties' into 'main'

Create properties to access encoder/decoder

Closes #251

See merge request !346
parents 9aae0c16 6f2c5cb4
No related branches found
No related tags found
1 merge request!346Create properties to access encoder/decoder
...@@ -21,6 +21,8 @@ from torch.nn.parallel import DistributedDataParallel as DDP ...@@ -21,6 +21,8 @@ from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm from tqdm import tqdm
from dan.ocr.decoder import GlobalHTADecoder
from dan.ocr.encoder import FCN_Encoder
from dan.ocr.manager.metrics import Inference, MetricManager from dan.ocr.manager.metrics import Inference, MetricManager
from dan.ocr.manager.ocr import OCRDatasetManager from dan.ocr.manager.ocr import OCRDatasetManager
from dan.ocr.mlflow import MLFLOW_AVAILABLE, logging_metrics, logging_tags_metrics from dan.ocr.mlflow import MLFLOW_AVAILABLE, logging_metrics, logging_tags_metrics
...@@ -31,7 +33,9 @@ if MLFLOW_AVAILABLE: ...@@ -31,7 +33,9 @@ if MLFLOW_AVAILABLE:
import mlflow import mlflow
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
MODEL_NAMES = ("encoder", "decoder") MODEL_NAME_ENCODER = "encoder"
MODEL_NAME_DECODER = "decoder"
MODEL_NAMES = (MODEL_NAME_ENCODER, MODEL_NAME_DECODER)
class GenericTrainingManager: class GenericTrainingManager:
...@@ -69,6 +73,14 @@ class GenericTrainingManager: ...@@ -69,6 +73,14 @@ class GenericTrainingManager:
self.init_paths() self.init_paths()
self.load_dataset() self.load_dataset()
@property
def encoder(self) -> FCN_Encoder | None:
    """Return the registered encoder model, or ``None`` when absent."""
    try:
        return self.models[MODEL_NAME_ENCODER]
    except KeyError:
        return None
@property
def decoder(self) -> GlobalHTADecoder | None:
    """Return the registered decoder model, or ``None`` when absent."""
    try:
        return self.models[MODEL_NAME_DECODER]
    except KeyError:
        return None
def init_paths(self): def init_paths(self):
""" """
Create output folders for results and checkpoints Create output folders for results and checkpoints
...@@ -985,20 +997,18 @@ class Manager(GenericTrainingManager): ...@@ -985,20 +997,18 @@ class Manager(GenericTrainingManager):
hidden_predict = None hidden_predict = None
cache = None cache = None
features = self.models["encoder"](batch_data["imgs"].to(self.device)) features = self.encoder(batch_data["imgs"].to(self.device))
features_size = features.size() features_size = features.size()
if self.device_params["use_ddp"]: if self.device_params["use_ddp"]:
features = self.models[ features = self.decoder.module.features_updater.get_pos_features(
"decoder"
].module.features_updater.get_pos_features(features)
else:
features = self.models["decoder"].features_updater.get_pos_features(
features features
) )
else:
features = self.decoder.features_updater.get_pos_features(features)
features = torch.flatten(features, start_dim=2, end_dim=3).permute(2, 0, 1) features = torch.flatten(features, start_dim=2, end_dim=3).permute(2, 0, 1)
output, pred, hidden_predict, cache, weights = self.models["decoder"]( output, pred, hidden_predict, cache, weights = self.decoder(
features, features,
simulated_y_pred[:, :-1], simulated_y_pred[:, :-1],
[s[:2] for s in batch_data["imgs_reduced_shape"]], [s[:2] for s in batch_data["imgs_reduced_shape"]],
...@@ -1058,7 +1068,7 @@ class Manager(GenericTrainingManager): ...@@ -1058,7 +1068,7 @@ class Manager(GenericTrainingManager):
for i in range(b): for i in range(b):
pos = batch_data["imgs_position"] pos = batch_data["imgs_position"]
features_list.append( features_list.append(
self.models["encoder"]( self.encoder(
x[ x[
i : i + 1, i : i + 1,
:, :,
...@@ -1079,21 +1089,19 @@ class Manager(GenericTrainingManager): ...@@ -1079,21 +1089,19 @@ class Manager(GenericTrainingManager):
i, :, : features_list[i].size(2), : features_list[i].size(3) i, :, : features_list[i].size(2), : features_list[i].size(3)
] = features_list[i] ] = features_list[i]
else: else:
features = self.models["encoder"](x) features = self.encoder(x)
features_size = features.size() features_size = features.size()
if self.device_params["use_ddp"]: if self.device_params["use_ddp"]:
features = self.models[ features = self.decoder.module.features_updater.get_pos_features(
"decoder"
].module.features_updater.get_pos_features(features)
else:
features = self.models["decoder"].features_updater.get_pos_features(
features features
) )
else:
features = self.decoder.features_updater.get_pos_features(features)
features = torch.flatten(features, start_dim=2, end_dim=3).permute(2, 0, 1) features = torch.flatten(features, start_dim=2, end_dim=3).permute(2, 0, 1)
for i in range(0, max_chars): for i in range(0, max_chars):
output, pred, hidden_predict, cache, weights = self.models["decoder"]( output, pred, hidden_predict, cache, weights = self.decoder(
features, features,
predicted_tokens, predicted_tokens,
[s[:2] for s in batch_data["imgs_reduced_shape"]], [s[:2] for s in batch_data["imgs_reduced_shape"]],
......
Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment