Skip to content

Commit

Permalink
fully implemented all function in nmt_model.py
Browse files Browse the repository at this point in the history
  • Loading branch information
AbdelrahmanAbouelenin committed May 12, 2020
1 parent e2f53d2 commit 1c29e0c
Showing 1 changed file with 27 additions and 1 deletion.
28 changes: 27 additions & 1 deletion assignment 4/nmt_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,18 @@ def step(self, Ybar_t: torch.Tensor,

### END YOUR CODE


dec_state = self.decoder(Ybar_t,dec_state)
dec_hidden, dec_cell = dec_state
dec_hidden = torch.unsqueeze(dec_hidden,dim=1)

enc_hiddens_proj = torch.transpose(enc_hiddens_proj,1,2)

e_t = torch.bmm(dec_hidden,enc_hiddens_proj)
e_t = torch.squeeze(e_t,dim=1)



# Set e_t to -inf where enc_masks has 1
if enc_masks is not None:
e_t.data.masked_fill_(enc_masks.bool(), -float('inf'))
Expand Down Expand Up @@ -366,8 +378,22 @@ def step(self, Ybar_t: torch.Tensor,
### https://pytorch.org/docs/stable/torch.html#torch.cat
### Tanh:
### https://pytorch.org/docs/stable/torch.html#torch.tanh

alpha_t = nn.functional.softmax(e_t,dim=1)

alpha_t = torch.unsqueeze(alpha_t,dim=1)


a_t = torch.bmm(alpha_t,enc_hiddens)

a_t = torch.squeeze(a_t,1)


dec_hidden = torch.squeeze(dec_hidden,1)
U_t = torch.cat((a_t,dec_hidden),dim=1)
V_t= self.combined_output_projection(U_t)
O_t = nn.functional.tanh(V_t)
O_t = self.dropout(O_t)

### END YOUR CODE

combined_output = O_t
Expand Down

0 comments on commit 1c29e0c

Please sign in to comment.