Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • atr/dan
1 result
Show changes
Commits on Source (4)
......@@ -40,6 +40,9 @@ class EntityType(NamedTuple):
def download_image(element: Element, im_path: Path):
if im_path.exists():
return im_path
tries = 1
# retry loop
while True:
......
......@@ -18,8 +18,6 @@ from torch.nn.init import xavier_uniform_
class PositionalEncoding1D(Module):
def __init__(self, dim, len_max, device):
super(PositionalEncoding1D, self).__init__()
self.len_max = len_max
self.dim = dim
self.pe = torch.zeros((1, dim, len_max), device=device, requires_grad=False)
div = torch.exp(
......@@ -46,9 +44,6 @@ class PositionalEncoding1D(Module):
class PositionalEncoding2D(Module):
def __init__(self, dim, h_max, w_max, device):
super(PositionalEncoding2D, self).__init__()
self.h_max = h_max
self.max_w = w_max
self.dim = dim
self.pe = torch.zeros(
(1, dim, h_max, w_max), device=device, requires_grad=False
)
......@@ -177,31 +172,28 @@ class GlobalDecoderLayer(Module):
def __init__(self, params):
super(GlobalDecoderLayer, self).__init__()
self.emb_dim = params["enc_dim"]
self.dim_feedforward = params["dec_dim_feedforward"]
self.self_att = CustomMultiHeadAttention(
embed_dim=self.emb_dim,
embed_dim=params["enc_dim"],
num_heads=params["dec_num_heads"],
proj_value=True,
dropout=params["dec_att_dropout"],
)
self.norm1 = LayerNorm(self.emb_dim)
self.norm1 = LayerNorm(params["enc_dim"])
self.att = CustomMultiHeadAttention(
embed_dim=self.emb_dim,
embed_dim=params["enc_dim"],
num_heads=params["dec_num_heads"],
proj_value=True,
dropout=params["dec_att_dropout"],
)
self.linear1 = Linear(self.emb_dim, self.dim_feedforward)
self.linear2 = Linear(self.dim_feedforward, self.emb_dim)
self.linear1 = Linear(params["enc_dim"], params["dec_dim_feedforward"])
self.linear2 = Linear(params["dec_dim_feedforward"], params["enc_dim"])
self.dropout = Dropout(params["dec_res_dropout"])
self.norm2 = LayerNorm(self.emb_dim)
self.norm3 = LayerNorm(self.emb_dim)
self.norm2 = LayerNorm(params["enc_dim"])
self.norm3 = LayerNorm(params["enc_dim"])
def forward(
self,
......@@ -319,11 +311,8 @@ class FeaturesUpdater(Module):
def __init__(self, params):
super(FeaturesUpdater, self).__init__()
self.enc_dim = params["enc_dim"]
self.enc_h_max = params["h_max"]
self.enc_w_max = params["w_max"]
self.pe_2d = PositionalEncoding2D(
self.enc_dim, self.enc_h_max, self.enc_w_max, params["device"]
params["enc_dim"], params["h_max"], params["w_max"], params["device"]
)
self.use_2d_positional_encoding = (
"use_2d_pe" not in params or params["use_2d_pe"]
......@@ -342,9 +331,6 @@ class GlobalHTADecoder(Module):
def __init__(self, params):
super(GlobalHTADecoder, self).__init__()
self.enc_dim = params["enc_dim"]
self.dec_l_max = params["l_max"]
self.dropout = Dropout(params["dec_pred_dropout"])
self.dec_att_win = (
params["attention_win"] if params["attention_win"] is not None else 1
......@@ -356,17 +342,17 @@ class GlobalHTADecoder(Module):
self.att_decoder = GlobalAttDecoder(params)
self.emb = Embedding(
num_embeddings=params["vocab_size"] + 3, embedding_dim=self.enc_dim
num_embeddings=params["vocab_size"] + 3, embedding_dim=params["enc_dim"]
)
self.pe_1d = PositionalEncoding1D(
self.enc_dim, self.dec_l_max, params["device"]
params["enc_dim"], params["l_max"], params["device"]
)
if self.use_lstm:
self.lstm_predict = LSTM(self.enc_dim, self.enc_dim)
self.lstm_predict = LSTM(params["enc_dim"], params["enc_dim"])
vocab_size = params["vocab_size"] + 1
self.end_conv = Conv1d(self.enc_dim, vocab_size, kernel_size=1)
self.end_conv = Conv1d(params["enc_dim"], vocab_size, kernel_size=1)
def forward(
self,
......
......@@ -18,6 +18,7 @@ def add_predict_parser(subcommands) -> None:
image_or_folder_input = parser.add_mutually_exclusive_group(required=True)
image_or_folder_input.add_argument(
"--image",
type=pathlib.Path,
help="Path to the image to predict.",
)
image_or_folder_input.add_argument(
......@@ -76,7 +77,7 @@ def add_predict_parser(subcommands) -> None:
type=float,
default=1.0,
help="Temperature scaling scalar parameter",
required=True,
required=False,
)
parser.add_argument(
"--confidence-score",
......