diff --git a/assignment 4/nmt_model.py b/assignment 4/nmt_model.py index 0e4b7c5..dcebbfc 100644 --- a/assignment 4/nmt_model.py +++ b/assignment 4/nmt_model.py @@ -336,6 +336,18 @@ def step(self, Ybar_t: torch.Tensor, ### END YOUR CODE + + dec_state = self.decoder(Ybar_t,dec_state) + dec_hidden, dec_cell = dec_state + dec_hidden = torch.unsqueeze(dec_hidden,dim=1) + + enc_hiddens_proj = torch.transpose(enc_hiddens_proj,1,2) + + e_t = torch.bmm(dec_hidden,enc_hiddens_proj) + e_t = torch.squeeze(e_t,dim=1) + + + # Set e_t to -inf where enc_masks has 1 if enc_masks is not None: e_t.data.masked_fill_(enc_masks.bool(), -float('inf')) @@ -366,8 +378,22 @@ def step(self, Ybar_t: torch.Tensor, ### https://pytorch.org/docs/stable/torch.html#torch.cat ### Tanh: ### https://pytorch.org/docs/stable/torch.html#torch.tanh + + alpha_t = nn.functional.softmax(e_t,dim=1) + + alpha_t = torch.unsqueeze(alpha_t,dim=1) - + a_t = torch.bmm(alpha_t,enc_hiddens) + + a_t = torch.squeeze(a_t,1) + + + dec_hidden = torch.squeeze(dec_hidden,1) + U_t = torch.cat((a_t,dec_hidden),dim=1) + V_t= self.combined_output_projection(U_t) + O_t = nn.functional.tanh(V_t) + O_t = self.dropout(O_t) + ### END YOUR CODE combined_output = O_t