Merge pull request #1916 from BishopLiu/master
FEA: modify the expensive tensor multiplication in repeatnet
BishopLiu authored Nov 15, 2023
2 parents ac60ad9 + c2a6759 commit 92e2f96
Showing 1 changed file with 11 additions and 48 deletions.
recbole/model/sequential_recommender/repeatnet.py (59 changes: 11 additions & 48 deletions)
@@ -248,11 +248,12 @@ def forward(self, all_memory, last_memory, item_seq, mask=None):
         output_er.masked_fill_(mask, -1e9)

         output_er = nn.Softmax(dim=-1)(output_er)
-        output_er = output_er.unsqueeze(1)

-        map_matrix = build_map(item_seq, self.device, max_index=self.num_item)
-        output_er = torch.matmul(output_er, map_matrix).squeeze(1).to(self.device)
-        repeat_recommendation_decoder = output_er.squeeze(1).to(self.device)
+        batch_size, b_len = item_seq.size()
+        repeat_recommendation_decoder = torch.zeros(
+            [batch_size, self.num_item], device=self.device
+        )
+        repeat_recommendation_decoder.scatter_add_(1, item_seq, output_er)

         return repeat_recommendation_decoder.to(self.device)
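The old path built a dense one-hot map_matrix of shape (batch_size, seq_len, n_items) with build_map and pushed the attention weights through a matmul; scatter_add_ accumulates the same per-item weight sums directly into a (batch_size, n_items) buffer without ever materialising that tensor. A minimal sketch with assumed toy shapes (not part of the PR) that checks the equivalence:

import torch

# Toy check (assumed sizes): the old one-hot matmul and the new scatter_add_
# both sum the softmax weights of each item id per row.
batch_size, seq_len, n_items = 2, 3, 6
item_seq = torch.tensor([[3, 4, 5], [2, 2, 0]])          # 0 = padding
weights = torch.softmax(torch.randn(batch_size, seq_len), dim=-1)

# Old path: build the (batch, seq_len, n_items) one-hot map, then matmul.
one_hot = torch.zeros(batch_size, seq_len, n_items)
one_hot.scatter_(2, item_seq.unsqueeze(2), 1.0)
old = torch.matmul(weights.unsqueeze(1), one_hot).squeeze(1)

# New path: accumulate the weights straight into a (batch, n_items) buffer.
new = torch.zeros(batch_size, n_items)
new.scatter_add_(1, item_seq, weights)

assert torch.allclose(old, new)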

@@ -299,50 +300,12 @@ def forward(self, all_memory, last_memory, item_seq, mask=None):
         output_e = torch.cat([output_e, last_memory_values], dim=1)
         output_e = self.dropout(self.matrix_for_explore(output_e))

-        map_matrix = build_map(item_seq, self.device, max_index=self.num_item)
-        explore_mask = torch.bmm(
-            (item_seq > 0).float().unsqueeze(1), map_matrix
-        ).squeeze(1)
-        output_e = output_e.masked_fill(explore_mask.bool(), float("-inf"))
+        item_seq_first = item_seq[:, 0].unsqueeze(1).expand_as(item_seq)
+        item_seq_first = item_seq_first.masked_fill(item_seq > 0, 0)
+        item_seq_first.requires_grad_(False)
+        output_e.scatter_add_(
+            1, item_seq + item_seq_first, float("-inf") * torch.ones_like(item_seq)
+        )
         explore_recommendation_decoder = nn.Softmax(1)(output_e)

         return explore_recommendation_decoder
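In the explore decoder the same idea replaces the bmm-based mask: adding -inf to the logits of items already in the session drives their softmax probability to zero, and the item_seq_first trick redirects padding positions (item id 0) to the sequence's first real item, which is already masked, so column 0 is never hit by the -inf. A small sketch with assumed toy values (not part of the PR):

import torch

# Toy demonstration of the -inf scatter trick used above (assumed sizes).
n_items = 6
item_seq = torch.tensor([[3, 4, 0, 0]])                  # left-aligned, 0 = padding
output_e = torch.randn(1, n_items)

item_seq_first = item_seq[:, 0].unsqueeze(1).expand_as(item_seq)
item_seq_first = item_seq_first.masked_fill(item_seq > 0, 0)   # first item id, only at padding slots
index = item_seq + item_seq_first                        # [3, 4, 3, 3]: padding re-routed to item 3
output_e.scatter_add_(1, index, float("-inf") * torch.ones_like(index, dtype=output_e.dtype))
probs = torch.softmax(output_e, dim=1)

assert probs[0, 3] == 0 and probs[0, 4] == 0             # session items excluded
assert probs[0, 0] > 0                                   # padding column untouched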


-def build_map(b_map, device, max_index=None):
-    """
-    Project b_map onto the positions where its entries belong, like this:
-        item_seq A: [3,4,5]    n_items: 6
-        after map, A becomes:
-            [0,0,1,0,0,0]
-            [0,0,0,1,0,0]
-            [0,0,0,0,1,0]
-        batch_size * seq_len  ==>>  batch_size * seq_len * n_item
-    Use in RepeatNet:
-        [3,4,5] matmul [0,0,1,0,0,0]
-                       [0,0,0,1,0,0]
-                       [0,0,0,0,1,0]
-        ==>> [0,0,3,4,5,0]; this is how RepeatNet projects sequence items onto the full item set.
-        batch_size * 1 * seq_len  matmul  batch_size * seq_len * n_item  ==>>  batch_size * 1 * n_item
-    """
-    batch_size, b_len = b_map.size()
-    if max_index is None:
-        max_index = b_map.max() + 1
-    if torch.cuda.is_available():
-        b_map_ = torch.FloatTensor(batch_size, b_len, max_index).fill_(0).to(device)
-    else:
-        b_map_ = torch.zeros(batch_size, b_len, max_index)
-    b_map_.scatter_(2, b_map.unsqueeze(2), 1.0)
-    b_map_.requires_grad = False
-    return b_map_
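The deleted build_map helper is the source of the "expensive tensor multiplication" in the commit title: it materialised a dense (batch_size, seq_len, n_items) float matrix on every forward pass, whereas the scatter-based replacement only needs a (batch_size, n_items) buffer. A back-of-the-envelope sketch with purely hypothetical sizes:

# Hypothetical sizes, for illustration only.
batch_size, seq_len, n_items = 256, 50, 40_000
bytes_per_float = 4

map_matrix_bytes = batch_size * seq_len * n_items * bytes_per_float
scatter_buffer_bytes = batch_size * n_items * bytes_per_float

print(f"one-hot map_matrix : {map_matrix_bytes / 2**30:.2f} GiB")      # ~1.91 GiB
print(f"scatter_add_ buffer: {scatter_buffer_bytes / 2**20:.2f} MiB")  # ~39.06 MiB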
