Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • atr/dan
1 result
Show changes
Commits on Source (2)
......@@ -21,6 +21,8 @@ from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from dan.ocr.decoder import GlobalHTADecoder
from dan.ocr.encoder import FCN_Encoder
from dan.ocr.manager.metrics import Inference, MetricManager
from dan.ocr.manager.ocr import OCRDatasetManager
from dan.ocr.mlflow import MLFLOW_AVAILABLE, logging_metrics, logging_tags_metrics
......@@ -31,7 +33,9 @@ if MLFLOW_AVAILABLE:
import mlflow
logger = logging.getLogger(__name__)
MODEL_NAMES = ("encoder", "decoder")
MODEL_NAME_ENCODER = "encoder"
MODEL_NAME_DECODER = "decoder"
MODEL_NAMES = (MODEL_NAME_ENCODER, MODEL_NAME_DECODER)
class GenericTrainingManager:
......@@ -69,6 +73,14 @@ class GenericTrainingManager:
self.init_paths()
self.load_dataset()
@property
def encoder(self) -> FCN_Encoder | None:
return self.models.get(MODEL_NAME_ENCODER)
@property
def decoder(self) -> GlobalHTADecoder | None:
return self.models.get(MODEL_NAME_DECODER)
def init_paths(self):
"""
Create output folders for results and checkpoints
......@@ -985,20 +997,18 @@ class Manager(GenericTrainingManager):
hidden_predict = None
cache = None
features = self.models["encoder"](batch_data["imgs"].to(self.device))
features = self.encoder(batch_data["imgs"].to(self.device))
features_size = features.size()
if self.device_params["use_ddp"]:
features = self.models[
"decoder"
].module.features_updater.get_pos_features(features)
else:
features = self.models["decoder"].features_updater.get_pos_features(
features = self.decoder.module.features_updater.get_pos_features(
features
)
else:
features = self.decoder.features_updater.get_pos_features(features)
features = torch.flatten(features, start_dim=2, end_dim=3).permute(2, 0, 1)
output, pred, hidden_predict, cache, weights = self.models["decoder"](
output, pred, hidden_predict, cache, weights = self.decoder(
features,
simulated_y_pred[:, :-1],
[s[:2] for s in batch_data["imgs_reduced_shape"]],
......@@ -1058,7 +1068,7 @@ class Manager(GenericTrainingManager):
for i in range(b):
pos = batch_data["imgs_position"]
features_list.append(
self.models["encoder"](
self.encoder(
x[
i : i + 1,
:,
......@@ -1079,21 +1089,19 @@ class Manager(GenericTrainingManager):
i, :, : features_list[i].size(2), : features_list[i].size(3)
] = features_list[i]
else:
features = self.models["encoder"](x)
features = self.encoder(x)
features_size = features.size()
if self.device_params["use_ddp"]:
features = self.models[
"decoder"
].module.features_updater.get_pos_features(features)
else:
features = self.models["decoder"].features_updater.get_pos_features(
features = self.decoder.module.features_updater.get_pos_features(
features
)
else:
features = self.decoder.features_updater.get_pos_features(features)
features = torch.flatten(features, start_dim=2, end_dim=3).permute(2, 0, 1)
for i in range(0, max_chars):
output, pred, hidden_predict, cache, weights = self.models["decoder"](
output, pred, hidden_predict, cache, weights = self.decoder(
features,
predicted_tokens,
[s[:2] for s in batch_data["imgs_reduced_shape"]],
......