Skip to content

Commit

Permalink
Merge pull request #2071 from HotBento/patch-2
Browse files Browse the repository at this point in the history
fix: call .cpu() only in the final step in collector.py to reduce CPU usage
  • Loading branch information
TayTroye authored Aug 29, 2024
2 parents d64724a + f77c3fe commit bfe05d0
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions recbole/evaluator/collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,12 @@ def set(self, name: str, value):

def update_tensor(self, name: str, value: torch.Tensor):
if name not in self._data_dict:
self._data_dict[name] = value.cpu().clone().detach()
self._data_dict[name] = value.clone().detach()
else:
if not isinstance(self._data_dict[name], torch.Tensor):
raise ValueError("{} is not a tensor.".format(name))
self._data_dict[name] = torch.cat(
(self._data_dict[name], value.cpu().clone().detach()), dim=0
(self._data_dict[name], value.clone().detach()), dim=0
)

def __str__(self):
Expand Down Expand Up @@ -149,13 +149,15 @@ def eval_batch_collect(
positive_i (torch.Tensor): the positive item id for each user.
"""
if self.register.need("rec.items"):

# get topk
_, topk_idx = torch.topk(
scores_tensor, max(self.topk), dim=-1
) # n_users x k
self.data_struct.update_tensor("rec.items", topk_idx)

if self.register.need("rec.topk"):

_, topk_idx = torch.topk(
scores_tensor, max(self.topk), dim=-1
) # n_users x k
Expand All @@ -167,6 +169,7 @@ def eval_batch_collect(
self.data_struct.update_tensor("rec.topk", result)

if self.register.need("rec.meanrank"):

desc_scores, desc_index = torch.sort(scores_tensor, dim=-1, descending=True)

# get the index of positive items in the ranking list
Expand All @@ -185,6 +188,7 @@ def eval_batch_collect(
self.data_struct.update_tensor("rec.meanrank", result)

if self.register.need("rec.score"):

self.data_struct.update_tensor("rec.score", scores_tensor)

if self.register.need("data.label"):
Expand Down Expand Up @@ -219,6 +223,8 @@ def get_data_struct(self):
"""Get all the evaluation resources that have been collected,
and reset some of the outdated resources.
"""
for key in self.data_struct._data_dict:
self.data_struct._data_dict[key] = self.data_struct._data_dict[key].cpu()
returned_struct = copy.deepcopy(self.data_struct)
for key in ["rec.topk", "rec.meanrank", "rec.score", "rec.items", "data.label"]:
if key in self.data_struct:
Expand Down

0 comments on commit bfe05d0

Please sign in to comment.