Hello, I think that if you want the additive attention to be able to handle batches, with inputs like these:
Inputs: query, value
- query (batch_size, q_len, hidden_dim): tensor containing the output features from the decoder.
- value (batch_size, v_len, hidden_dim): tensor containing features of the encoded input sequence.
then the code in the forward function should be like this:
def forward(self, query: Tensor, key: Tensor, value: Tensor):
    # key_proj(key.unsqueeze(1)):    (batch, 1, v_len, hidden)
    # query_proj(query.unsqueeze(2)): (batch, q_len, 1, hidden)
    # their sum broadcasts to (batch, q_len, v_len, hidden)
    score = self.score_proj(
        torch.tanh(self.key_proj(key.unsqueeze(1))
                   + self.query_proj(query.unsqueeze(2))
                   + self.bias)
    ).squeeze(-1)                     # (batch, q_len, v_len); squeeze(-1) so batch_size == 1 survives
    attn = F.softmax(score, dim=-1)   # normalize over the v_len (key) axis
    context = torch.bmm(attn, value)  # (batch, q_len, hidden_dim)
    return context, attn
Otherwise, the shapes of self.key_proj(key.unsqueeze(1)) and self.query_proj(query.unsqueeze(2)) will mismatch on the second dimension (v_len vs. q_len) and cannot be added.