diff --git a/recbole/model/sequential_recommender/repeatnet.py b/recbole/model/sequential_recommender/repeatnet.py
index e3506f2e7..cc7acf8d3 100644
--- a/recbole/model/sequential_recommender/repeatnet.py
+++ b/recbole/model/sequential_recommender/repeatnet.py
@@ -248,11 +248,12 @@ def forward(self, all_memory, last_memory, item_seq, mask=None):
             output_er.masked_fill_(mask, -1e9)

         output_er = nn.Softmax(dim=-1)(output_er)
-        output_er = output_er.unsqueeze(1)
-        map_matrix = build_map(item_seq, self.device, max_index=self.num_item)
-        output_er = torch.matmul(output_er, map_matrix).squeeze(1).to(self.device)
-        repeat_recommendation_decoder = output_er.squeeze(1).to(self.device)
+        batch_size, b_len = item_seq.size()
+        repeat_recommendation_decoder = torch.zeros(
+            [batch_size, self.num_item], device=self.device
+        )
+        repeat_recommendation_decoder.scatter_add_(1, item_seq, output_er)

         return repeat_recommendation_decoder.to(self.device)


@@ -299,50 +300,12 @@ def forward(self, all_memory, last_memory, item_seq, mask=None):
         output_e = torch.cat([output_e, last_memory_values], dim=1)
         output_e = self.dropout(self.matrix_for_explore(output_e))

-        map_matrix = build_map(item_seq, self.device, max_index=self.num_item)
-        explore_mask = torch.bmm(
-            (item_seq > 0).float().unsqueeze(1), map_matrix
-        ).squeeze(1)
-        output_e = output_e.masked_fill(explore_mask.bool(), float("-inf"))
+        item_seq_first = item_seq[:, 0].unsqueeze(1).expand_as(item_seq)
+        item_seq_first = item_seq_first.masked_fill(item_seq > 0, 0)
+        item_seq_first.requires_grad_(False)
+        output_e.scatter_add_(
+            1, item_seq + item_seq_first, float("-inf") * torch.ones_like(item_seq)
+        )
         explore_recommendation_decoder = nn.Softmax(1)(output_e)

         return explore_recommendation_decoder
-
-
-def build_map(b_map, device, max_index=None):
-    """
-    project the b_map to the place where it in should be like this:
-        item_seq A: [3,4,5]   n_items: 6
-
-        after map: A
-
-        [0,0,1,0,0,0]
-
-        [0,0,0,1,0,0]
-
-        [0,0,0,0,1,0]
-
-        batch_size * seq_len ==>> batch_size * seq_len * n_item
-
-        use in RepeatNet:
-
-        [3,4,5] matmul [0,0,1,0,0,0]
-
-                       [0,0,0,1,0,0]
-
-                       [0,0,0,0,1,0]
-
-        ==>>> [0,0,3,4,5,0] it works in the RepeatNet when project the seq item into all items
-
-        batch_size * 1 * seq_len matmul batch_size * seq_len * n_item ==>> batch_size * 1 * n_item
-    """
-    batch_size, b_len = b_map.size()
-    if max_index is None:
-        max_index = b_map.max() + 1
-    if torch.cuda.is_available():
-        b_map_ = torch.FloatTensor(batch_size, b_len, max_index).fill_(0).to(device)
-    else:
-        b_map_ = torch.zeros(batch_size, b_len, max_index)
-    b_map_.scatter_(2, b_map.unsqueeze(2), 1.0)
-    b_map_.requires_grad = False
-    return b_map_
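
A quick way to see why the first hunk is safe: the old path multiplied the position-level softmax by the one-hot map from build_map, which materializes a (batch_size, seq_len, num_item) tensor, while scatter_add_ writes the same per-position weights directly into a (batch_size, num_item) buffer, summing the weights of items that repeat in the sequence exactly as the matmul did. A minimal sketch of the equivalence (my own, not part of the patch; the toy batch, seq_len of 3, and num_item = 6 are made up for illustration):

```python
import torch

torch.manual_seed(0)
num_item = 6                                           # assumed toy catalog size
item_seq = torch.tensor([[3, 4, 5], [2, 2, 0]])        # (batch, seq_len), 0 = padding
output_er = torch.softmax(torch.randn(2, 3), dim=-1)   # softmax over sequence positions

# Old path: build the one-hot map inline (what build_map produced), then matmul.
map_matrix = torch.zeros(2, 3, num_item)               # (batch, seq_len, num_item)
map_matrix.scatter_(2, item_seq.unsqueeze(2), 1.0)
old = torch.matmul(output_er.unsqueeze(1), map_matrix).squeeze(1)

# New path: scatter the per-position weights straight into an item-sized buffer.
# Repeated items (row 1) and padding (index 0) accumulate, matching the matmul.
new = torch.zeros(2, num_item)
new.scatter_add_(1, item_seq, output_er)

assert torch.allclose(old, new)
```

Peak memory for this step drops from O(batch_size * seq_len * num_item) to O(batch_size * num_item), which appears to be the point of dropping build_map.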
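
The second hunk applies the same scatter trick to the explore mask, with one subtlety: padded positions hold item 0, so scattering -inf through item_seq directly would mask index 0, which the old (item_seq > 0) filter deliberately left alone. Redirecting padded positions to the sequence's first real item (already being masked anyway, and -inf + -inf stays -inf) reproduces the old mask without the one-hot map. An illustrative sketch (again mine; the 4-step sequence and 6-item catalog are assumptions, not from the patch):

```python
import torch

item_seq = torch.tensor([[3, 4, 0, 0]])   # items 3 and 4 seen, then two padded slots
output_e = torch.zeros(1, 6)              # explore logits over an assumed 6-item catalog

# Zero at real positions, the first item at padded ones.
item_seq_first = item_seq[:, 0].unsqueeze(1).expand_as(item_seq)
item_seq_first = item_seq_first.masked_fill(item_seq > 0, 0)
# item_seq_first == [[0, 0, 3, 3]], so item_seq + item_seq_first == [[3, 4, 3, 3]]

output_e.scatter_add_(
    1, item_seq + item_seq_first, float("-inf") * torch.ones_like(item_seq)
)
print(output_e)  # tensor([[0., 0., 0., -inf, -inf, 0.]]); seen items masked, index 0 untouched
```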