Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • atr/dan
1 result
Show changes
Commits on Source (2)
...@@ -21,6 +21,8 @@ from torch.nn.parallel import DistributedDataParallel as DDP ...@@ -21,6 +21,8 @@ from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm from tqdm import tqdm
from dan.ocr.decoder import GlobalHTADecoder
from dan.ocr.encoder import FCN_Encoder
from dan.ocr.manager.metrics import Inference, MetricManager from dan.ocr.manager.metrics import Inference, MetricManager
from dan.ocr.manager.ocr import OCRDatasetManager from dan.ocr.manager.ocr import OCRDatasetManager
from dan.ocr.mlflow import MLFLOW_AVAILABLE, logging_metrics, logging_tags_metrics from dan.ocr.mlflow import MLFLOW_AVAILABLE, logging_metrics, logging_tags_metrics
...@@ -31,7 +33,9 @@ if MLFLOW_AVAILABLE: ...@@ -31,7 +33,9 @@ if MLFLOW_AVAILABLE:
import mlflow import mlflow
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
MODEL_NAMES = ("encoder", "decoder") MODEL_NAME_ENCODER = "encoder"
MODEL_NAME_DECODER = "decoder"
MODEL_NAMES = (MODEL_NAME_ENCODER, MODEL_NAME_DECODER)
class GenericTrainingManager: class GenericTrainingManager:
...@@ -69,6 +73,14 @@ class GenericTrainingManager: ...@@ -69,6 +73,14 @@ class GenericTrainingManager:
self.init_paths() self.init_paths()
self.load_dataset() self.load_dataset()
@property
def encoder(self) -> FCN_Encoder | None:
    """
    Shortcut to the encoder entry of ``self.models``.

    Looks up ``MODEL_NAME_ENCODER`` with ``self.models.get`` and hands back
    whatever is stored there. Assuming ``self.models`` is a dict-like
    mapping, the result is ``None`` when no encoder has been registered.
    """
    encoder_model = self.models.get(MODEL_NAME_ENCODER)
    return encoder_model
@property
def decoder(self) -> GlobalHTADecoder | None:
    """
    Shortcut to the decoder entry of ``self.models``.

    Looks up ``MODEL_NAME_DECODER`` with ``self.models.get`` and hands back
    whatever is stored there. Assuming ``self.models`` is a dict-like
    mapping, the result is ``None`` when no decoder has been registered.

    NOTE(review): callers that run with ``use_ddp`` enabled access
    ``self.decoder.module``, so the stored object is presumably a wrapper
    (likely ``DistributedDataParallel``) in that mode rather than a bare
    ``GlobalHTADecoder`` -- confirm and widen the annotation if so.
    """
    decoder_model = self.models.get(MODEL_NAME_DECODER)
    return decoder_model
def init_paths(self): def init_paths(self):
""" """
Create output folders for results and checkpoints Create output folders for results and checkpoints
...@@ -985,20 +997,18 @@ class Manager(GenericTrainingManager): ...@@ -985,20 +997,18 @@ class Manager(GenericTrainingManager):
hidden_predict = None hidden_predict = None
cache = None cache = None
features = self.models["encoder"](batch_data["imgs"].to(self.device)) features = self.encoder(batch_data["imgs"].to(self.device))
features_size = features.size() features_size = features.size()
if self.device_params["use_ddp"]: if self.device_params["use_ddp"]:
features = self.models[ features = self.decoder.module.features_updater.get_pos_features(
"decoder"
].module.features_updater.get_pos_features(features)
else:
features = self.models["decoder"].features_updater.get_pos_features(
features features
) )
else:
features = self.decoder.features_updater.get_pos_features(features)
features = torch.flatten(features, start_dim=2, end_dim=3).permute(2, 0, 1) features = torch.flatten(features, start_dim=2, end_dim=3).permute(2, 0, 1)
output, pred, hidden_predict, cache, weights = self.models["decoder"]( output, pred, hidden_predict, cache, weights = self.decoder(
features, features,
simulated_y_pred[:, :-1], simulated_y_pred[:, :-1],
[s[:2] for s in batch_data["imgs_reduced_shape"]], [s[:2] for s in batch_data["imgs_reduced_shape"]],
...@@ -1058,7 +1068,7 @@ class Manager(GenericTrainingManager): ...@@ -1058,7 +1068,7 @@ class Manager(GenericTrainingManager):
for i in range(b): for i in range(b):
pos = batch_data["imgs_position"] pos = batch_data["imgs_position"]
features_list.append( features_list.append(
self.models["encoder"]( self.encoder(
x[ x[
i : i + 1, i : i + 1,
:, :,
...@@ -1079,21 +1089,19 @@ class Manager(GenericTrainingManager): ...@@ -1079,21 +1089,19 @@ class Manager(GenericTrainingManager):
i, :, : features_list[i].size(2), : features_list[i].size(3) i, :, : features_list[i].size(2), : features_list[i].size(3)
] = features_list[i] ] = features_list[i]
else: else:
features = self.models["encoder"](x) features = self.encoder(x)
features_size = features.size() features_size = features.size()
if self.device_params["use_ddp"]: if self.device_params["use_ddp"]:
features = self.models[ features = self.decoder.module.features_updater.get_pos_features(
"decoder"
].module.features_updater.get_pos_features(features)
else:
features = self.models["decoder"].features_updater.get_pos_features(
features features
) )
else:
features = self.decoder.features_updater.get_pos_features(features)
features = torch.flatten(features, start_dim=2, end_dim=3).permute(2, 0, 1) features = torch.flatten(features, start_dim=2, end_dim=3).permute(2, 0, 1)
for i in range(0, max_chars): for i in range(0, max_chars):
output, pred, hidden_predict, cache, weights = self.models["decoder"]( output, pred, hidden_predict, cache, weights = self.decoder(
features, features,
predicted_tokens, predicted_tokens,
[s[:2] for s in batch_data["imgs_reduced_shape"]], [s[:2] for s in batch_data["imgs_reduced_shape"]],
......