Commit 33958fdb authored by Denis Coquenet

add lstm option

parent 09f1d041
```diff
@@ -151,6 +151,7 @@ if __name__ == "__main__":
         "dec_dim_feedforward": 256,  # number of dimension for feedforward layer in transformer decoder layers
         "use_2d_pe": True,  # use 2D positional embedding
         "use_1d_pe": True,  # use 1D positional embedding
+        "use_lstm": False,
         "attention_win": 100,  # length of attention window
         # Curriculum dropout
         "dropout_scheduler": {
```
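The new flag defaults to `False`, so existing configurations keep their behaviour. Enabling the LSTM prediction layer is then a one-line change in the model parameters; a minimal sketch, assuming only the keys shown in the hunk above:

```python
# Hypothetical excerpt of the model parameter dict; only "use_lstm" is new in
# this commit, the other keys are unchanged context from the hunk above.
model_params = {
    "use_2d_pe": True,     # use 2D positional embedding
    "use_1d_pe": True,     # use 1D positional embedding
    "use_lstm": True,      # route decoder output through an LSTM before prediction
    "attention_win": 100,  # length of attention window
}
```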
```diff
@@ -262,22 +262,24 @@ class GlobalHTADecoder(Module):
         self.enc_dim = params["enc_dim"]
         self.dec_l_max = params["l_max"]
-        self.features_updater = FeaturesUpdater(params)
         self.dropout = Dropout(params["dec_pred_dropout"])
         self.dec_att_win = params["attention_win"] if params["attention_win"] is not None else 1
+        self.use_1d_pe = "use_1d_pe" not in params or params["use_1d_pe"]
+        self.use_lstm = params["use_lstm"]
+        self.features_updater = FeaturesUpdater(params)
         self.att_decoder = GlobalAttDecoder(params)
         self.emb = Embedding(num_embeddings=params["vocab_size"]+3, embedding_dim=self.enc_dim)
         self.pe_1d = PositionalEncoding1D(self.enc_dim, self.dec_l_max, params["device"])
-        self.use_1d_pe = "use_1d_pe" not in params or params["use_1d_pe"]
-        vocab_size = params["vocab_size"] + 1
+        if self.use_lstm:
+            self.lstm_predict = LSTM(self.enc_dim, self.enc_dim)
+        vocab_size = params["vocab_size"] + 1
         self.end_conv = Conv1d(self.enc_dim, vocab_size, kernel_size=1)

-    def forward(self, raw_features_1d, enhanced_features_1d, tokens, reduced_size, token_len, features_size, start=0, hidden_emb=None, hidden_predict=None, cache=None, num_pred=None, keep_all_weights=False, token_line=None, token_pg=None):
+    def forward(self, raw_features_1d, enhanced_features_1d, tokens, reduced_size, token_len, features_size, start=0, hidden_predict=None, cache=None, num_pred=None, keep_all_weights=False, token_line=None, token_pg=None):
         device = raw_features_1d.device
         # Token to Embedding
```
```diff
@@ -321,12 +323,15 @@ class GlobalHTADecoder(Module):
                                                       predict_last_n_only=num_pred,
                                                       keep_all_weights=keep_all_weights)
+        if self.use_lstm:
+            output, hidden_predict = self.lstm_predict(output, hidden_predict)
         dp_output = self.dropout(relu(output))
         preds = self.end_conv(dp_output.permute(1, 2, 0))
         if not keep_all_weights:
             weights = torch.sum(weights, dim=1, keepdim=True).reshape(-1, 1, features_size[2], features_size[3])
-        return output, preds, hidden_emb, hidden_predict, cache, weights
+        return output, preds, hidden_predict, cache, weights

     def generate_enc_mask(self, batch_reduced_size, total_size, device):
         """
```
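Taken together, the two decoder hunks gate a single-layer LSTM between the transformer decoder output and the 1x1 prediction convolution, and drop the unused `hidden_emb` state from the `forward` signature and return value. A minimal, self-contained sketch of the new prediction path, with shapes and hyperparameter values assumed rather than taken from the repository:

```python
import torch
from torch import relu
from torch.nn import LSTM, Conv1d, Dropout

enc_dim, vocab_size = 256, 100                # assumed values
use_lstm = True                               # plays the role of params["use_lstm"]

lstm_predict = LSTM(enc_dim, enc_dim)         # single layer, input size == hidden size
end_conv = Conv1d(enc_dim, vocab_size + 1, kernel_size=1)
dropout = Dropout(0.1)                        # stands in for params["dec_pred_dropout"]

output = torch.rand(10, 2, enc_dim)           # (seq_len, batch, enc_dim) from the transformer
hidden_predict = None                         # LSTM state (h_0, c_0); None on the first step

if use_lstm:
    # nn.LSTM defaults to (seq_len, batch, input_size) input, matching `output`
    output, hidden_predict = lstm_predict(output, hidden_predict)

dp_output = dropout(relu(output))
# Conv1d expects (batch, channels, length): one logit vector per decoding step
preds = end_conv(dp_output.permute(1, 2, 0))  # (batch, vocab_size + 1, seq_len)
```

Because the LSTM state is returned alongside the predictions, the same module serves teacher-forced training (whole sequence in one pass) and step-by-step decoding, where the state is fed back in at each character.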
```diff
@@ -54,7 +54,6 @@ class Manager(OCRManager):
         simulated_y_pred = y
         with autocast(enabled=self.params["training_params"]["use_amp"]):
-            hidden_emb = None
             hidden_predict = None
             cache = None
```
```diff
@@ -66,12 +65,12 @@ class Manager(OCRManager):
             features = torch.flatten(pos_features, start_dim=2, end_dim=3).permute(2, 0, 1)
             enhanced_features = pos_features
             enhanced_features = torch.flatten(enhanced_features, start_dim=2, end_dim=3).permute(2, 0, 1)
-            output, pred, hidden_emb, hidden_predict, cache, weights = self.models["decoder"](features, enhanced_features,
+            output, pred, hidden_predict, cache, weights = self.models["decoder"](features, enhanced_features,
                                                                                   simulated_y_pred[:, :-1],
                                                                                   reduced_size,
                                                                                   [max(y_len) for _ in range(b)],
                                                                                   features_size,
-                                                                                  start=0, hidden_emb=hidden_emb,
+                                                                                  start=0,
                                                                                   hidden_predict=hidden_predict,
                                                                                   cache=cache,
                                                                                   keep_all_weights=True)
```
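For context, the training pass stays a single teacher-forced call rather than a character loop: the decoder reads `simulated_y_pred[:, :-1]`, i.e. the ground truth without its last token, and is trained to produce the same sequence advanced by one position. A toy illustration with assumed token ids:

```python
# Hypothetical token ids; SOS/EOS stand for the <start>/<end> symbols.
import torch

SOS, EOS = 0, 1
y = torch.tensor([[SOS, 5, 8, 2, EOS]])   # plays the role of simulated_y_pred
decoder_input = y[:, :-1]                 # [SOS, 5, 8, 2] -> fed to the decoder
target = y[:, 1:]                         # [5, 8, 2, EOS] -> aligned with `pred`
```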
```diff
@@ -115,7 +114,6 @@ class Manager(OCRManager):
         confidence_scores = list()
         cache = None
         hidden_predict = None
-        hidden_emb = None
         if b > 1:
             features_list = list()
             for i in range(b):
```
```diff
@@ -136,7 +134,7 @@ class Manager(OCRManager):
             enhanced_features = torch.flatten(enhanced_features, start_dim=2, end_dim=3).permute(2, 0, 1)
             for i in range(0, max_chars):
-                output, pred, hidden_emb, hidden_predict, cache, weights = self.models["decoder"](features, enhanced_features, predicted_tokens, reduced_size, predicted_tokens_len, features_size, start=0, hidden_emb=hidden_emb, hidden_predict=hidden_predict, cache=cache, num_pred=1)
+                output, pred, hidden_predict, cache, weights = self.models["decoder"](features, enhanced_features, predicted_tokens, reduced_size, predicted_tokens_len, features_size, start=0, hidden_predict=hidden_predict, cache=cache, num_pred=1)
                 whole_output.append(output)
                 confidence_scores.append(torch.max(torch.softmax(pred[:, :], dim=1), dim=1).values)
                 coverage_vector = torch.clamp(coverage_vector + weights, 0, 1)
```
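After this commit, the only recurrent states threaded through the character-by-character evaluation loop are `hidden_predict` (the LSTM state when `use_lstm` is on, `None` otherwise) and the attention `cache`; `hidden_emb` is gone from both the arguments and the returned tuple. A minimal sketch of that loop as a hypothetical helper (not repository code; the greedy token choice and length bookkeeping are assumptions):

```python
import torch

def greedy_decode(decoder, features, enhanced_features, reduced_size,
                  features_size, start_tokens, max_chars):
    """Greedy char-by-char decoding; `decoder` follows the new 5-tuple API."""
    hidden_predict, cache = None, None                 # states carried across steps
    predicted_tokens = start_tokens                    # (batch, 1) <start> token ids
    predicted_tokens_len = [1] * start_tokens.size(0)
    coverage_vector = None
    for _ in range(max_chars):
        output, pred, hidden_predict, cache, weights = decoder(
            features, enhanced_features, predicted_tokens, reduced_size,
            predicted_tokens_len, features_size,
            start=0, hidden_predict=hidden_predict, cache=cache, num_pred=1)
        next_token = torch.argmax(pred[:, :, -1], dim=1, keepdim=True)  # greedy pick
        predicted_tokens = torch.cat([predicted_tokens, next_token], dim=1)
        predicted_tokens_len = [n + 1 for n in predicted_tokens_len]
        coverage_vector = weights if coverage_vector is None \
            else torch.clamp(coverage_vector + weights, 0, 1)
    return predicted_tokens, coverage_vector
```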