Skip to content
Snippets Groups Projects
Commit d87bc70f authored by Marie Generali's avatar Marie Generali :worried:
Browse files

fix tests

parent 3de1f1b6
No related branches found
No related tags found
No related merge requests found
......@@ -1036,7 +1036,7 @@ class Manager(OCRManager):
enhanced_features = torch.flatten(
enhanced_features, start_dim=2, end_dim=3
).permute(2, 0, 1)
output, pred, hidden_predict, cache, weights = self.models["decoder"](
output, pred, hidden_predict, cache, weights, temperature = self.models["decoder"](
features,
enhanced_features,
simulated_y_pred[:, :-1],
......@@ -1133,7 +1133,7 @@ class Manager(OCRManager):
).permute(2, 0, 1)
for i in range(0, max_chars):
output, pred, hidden_predict, cache, weights = self.models["decoder"](
output, pred, hidden_predict, cache, weights, temperature = self.models["decoder"](
features,
enhanced_features,
predicted_tokens,
......
......@@ -167,6 +167,7 @@ def get_config():
"use_2d_pe": True, # use 2D positional embedding
"use_1d_pe": True, # use 1D positional embedding
"use_lstm": False,
"temperature": 1, # temperature scaling scalar parameter
"attention_win": 100, # length of attention window
# Curriculum dropout
"dropout_scheduler": {
......
......@@ -107,6 +107,7 @@ def training_config():
"dec_pred_dropout": 0.1, # dropout rate before decision layer
"dec_att_dropout": 0.1, # dropout rate in multi head attention
"dec_dim_feedforward": 256, # number of dimension for feedforward layer in transformer decoder layers
"temperature": 1, # temperature scaling scalar parameter
"use_2d_pe": True, # use 2D positional embedding
"use_1d_pe": True, # use 1D positional embedding
"use_lstm": False,
......
......@@ -22,3 +22,4 @@ parameters:
dec_num_heads: 4
dec_att_dropout: 0.1
dec_res_dropout: 0.1
temperature: 1
......@@ -79,6 +79,8 @@ def test_train_and_test(
expected_param,
expected_tensor,
) in zip(trained.items(), expected.items()):
print(f'trained tensor is {trained_tensor}')
print(f'expected tensor is {expected_tensor}')
assert trained_param == expected_param
assert torch.allclose(trained_tensor, expected_tensor, atol=1e-03)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment