Commit 389c3505 authored by Yoann Schneider

Merge branch 'train-model-properties' into 'main'

Create properties to access encoder/decoder

Closes #251

See merge request !346
parents 9aae0c16 6f2c5cb4
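For context, a minimal sketch of the pattern this merge request introduces, based on the hunks below: the encoder and decoder are still stored in the manager's models dictionary, but call sites now reach them through typed properties instead of string keys. The stub model classes and the trimmed-down manager are illustrative only; the real classes live in the dan.ocr package.

# Minimal sketch of the property pattern added in this merge request.
# FCN_Encoder and GlobalHTADecoder are stubbed here for illustration;
# the real GenericTrainingManager holds many more attributes.
from __future__ import annotations


class FCN_Encoder:  # stand-in for dan.ocr.encoder.FCN_Encoder
    pass


class GlobalHTADecoder:  # stand-in for dan.ocr.decoder.GlobalHTADecoder
    pass


MODEL_NAME_ENCODER = "encoder"
MODEL_NAME_DECODER = "decoder"
MODEL_NAMES = (MODEL_NAME_ENCODER, MODEL_NAME_DECODER)


class GenericTrainingManager:
    def __init__(self) -> None:
        # Models are still stored in a dict keyed by MODEL_NAMES ...
        self.models: dict[str, FCN_Encoder | GlobalHTADecoder] = {}

    # ... but call sites can now use properties instead of string keys.
    @property
    def encoder(self) -> FCN_Encoder | None:
        return self.models.get(MODEL_NAME_ENCODER)

    @property
    def decoder(self) -> GlobalHTADecoder | None:
        return self.models.get(MODEL_NAME_DECODER)


manager = GenericTrainingManager()
assert manager.encoder is None  # dict.get() returns None until models are built
manager.models[MODEL_NAME_ENCODER] = FCN_Encoder()
assert isinstance(manager.encoder, FCN_Encoder)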
@@ -21,6 +21,8 @@ from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
 from tqdm import tqdm

+from dan.ocr.decoder import GlobalHTADecoder
+from dan.ocr.encoder import FCN_Encoder
 from dan.ocr.manager.metrics import Inference, MetricManager
 from dan.ocr.manager.ocr import OCRDatasetManager
 from dan.ocr.mlflow import MLFLOW_AVAILABLE, logging_metrics, logging_tags_metrics
@@ -31,7 +33,9 @@ if MLFLOW_AVAILABLE:
     import mlflow

 logger = logging.getLogger(__name__)

-MODEL_NAMES = ("encoder", "decoder")
+MODEL_NAME_ENCODER = "encoder"
+MODEL_NAME_DECODER = "decoder"
+MODEL_NAMES = (MODEL_NAME_ENCODER, MODEL_NAME_DECODER)

 class GenericTrainingManager:
@@ -69,6 +73,14 @@ class GenericTrainingManager:
         self.init_paths()
         self.load_dataset()

+    @property
+    def encoder(self) -> FCN_Encoder | None:
+        return self.models.get(MODEL_NAME_ENCODER)
+
+    @property
+    def decoder(self) -> GlobalHTADecoder | None:
+        return self.models.get(MODEL_NAME_DECODER)
+
     def init_paths(self):
         """
         Create output folders for results and checkpoints
@@ -985,20 +997,18 @@ class Manager(GenericTrainingManager):
         hidden_predict = None
         cache = None

-        features = self.models["encoder"](batch_data["imgs"].to(self.device))
+        features = self.encoder(batch_data["imgs"].to(self.device))
         features_size = features.size()

         if self.device_params["use_ddp"]:
-            features = self.models[
-                "decoder"
-            ].module.features_updater.get_pos_features(features)
-        else:
-            features = self.models["decoder"].features_updater.get_pos_features(
+            features = self.decoder.module.features_updater.get_pos_features(
                 features
             )
+        else:
+            features = self.decoder.features_updater.get_pos_features(features)
         features = torch.flatten(features, start_dim=2, end_dim=3).permute(2, 0, 1)

-        output, pred, hidden_predict, cache, weights = self.models["decoder"](
+        output, pred, hidden_predict, cache, weights = self.decoder(
             features,
             simulated_y_pred[:, :-1],
             [s[:2] for s in batch_data["imgs_reduced_shape"]],
@@ -1058,7 +1068,7 @@
             for i in range(b):
                 pos = batch_data["imgs_position"]
                 features_list.append(
-                    self.models["encoder"](
+                    self.encoder(
                         x[
                             i : i + 1,
                             :,
@@ -1079,21 +1089,19 @@
                     i, :, : features_list[i].size(2), : features_list[i].size(3)
                 ] = features_list[i]
         else:
-            features = self.models["encoder"](x)
+            features = self.encoder(x)
         features_size = features.size()

         if self.device_params["use_ddp"]:
-            features = self.models[
-                "decoder"
-            ].module.features_updater.get_pos_features(features)
-        else:
-            features = self.models["decoder"].features_updater.get_pos_features(
+            features = self.decoder.module.features_updater.get_pos_features(
                 features
             )
+        else:
+            features = self.decoder.features_updater.get_pos_features(features)
         features = torch.flatten(features, start_dim=2, end_dim=3).permute(2, 0, 1)

         for i in range(0, max_chars):
-            output, pred, hidden_predict, cache, weights = self.models["decoder"](
+            output, pred, hidden_predict, cache, weights = self.decoder(
                 features,
                 predicted_tokens,
                 [s[:2] for s in batch_data["imgs_reduced_shape"]],
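As a usage note, not part of the merge request itself: the call-site edits above all follow the same shape, and the branching on use_ddp remains because DistributedDataParallel wraps the network, so custom attributes such as features_updater are only reachable through .module. A minimal sketch with stand-in classes, since building a real DDP wrapper needs an initialized process group:

# Sketch of the call-site pattern; DummyDecoder and FakeDDP are illustrative
# stand-ins, not the real DAN or torch classes. A real DistributedDataParallel
# wrapper exposes the wrapped network as .module in the same way.
from torch import nn


class DummyDecoder(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # stand-in for the positional-feature module used in the diff
        self.features_updater = nn.Identity()


class FakeDDP(nn.Module):
    """Mimics only the .module attribute of DistributedDataParallel."""

    def __init__(self, module: nn.Module) -> None:
        super().__init__()
        self.module = module


use_ddp = True
decoder = FakeDDP(DummyDecoder()) if use_ddp else DummyDecoder()

# Same branching as in the hunks above: unwrap .module only when DDP is active.
if use_ddp:
    features_updater = decoder.module.features_updater
else:
    features_updater = decoder.features_updater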