Skip to content

Commit

Permalink
Fix division by zero in Hits@k metric by checking for length
Browse files Browse the repository at this point in the history
  • Loading branch information
ACHinrichs committed Jan 27, 2025
1 parent 9bffcd4 commit d144d0f
Showing 1 changed file with 16 additions and 8 deletions.
24 changes: 16 additions & 8 deletions torch_geometric/nn/kge/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class KGEModel(torch.nn.Module):
sparse (bool, optional): If set to :obj:`True`, gradients w.r.t. to the
embedding matrices will be sparse. (default: :obj:`False`)
"""

def __init__(
self,
num_nodes: int,
Expand Down Expand Up @@ -123,15 +124,19 @@ def test(
tail_indices = torch.arange(self.num_nodes, device=t.device)
for ts in tail_indices.split(batch_size):
scores.append(self(h.expand_as(ts), r.expand_as(ts), ts))
rank = int((torch.cat(scores).argsort(
descending=True) == t).nonzero().view(-1))
rank = int(
(torch.cat(scores).argsort(descending=True) == t).nonzero().view(-1)
)
mean_ranks.append(rank)
reciprocal_ranks.append(1 / (rank + 1))
hits_at_k.append(rank < k)

mean_rank = float(torch.tensor(mean_ranks, dtype=torch.float).mean())
mrr = float(torch.tensor(reciprocal_ranks, dtype=torch.float).mean())
hits_at_k = int(torch.tensor(hits_at_k).sum()) / len(hits_at_k)
if len(hits_at_k) == 0:
hits_at_k = 0.0
else:
hits_at_k = int(torch.tensor(hits_at_k).sum()) / len(hits_at_k)

return mean_rank, mrr, hits_at_k

Expand All @@ -152,8 +157,9 @@ def random_sample(
"""
# Random sample either `head_index` or `tail_index` (but not both):
num_negatives = head_index.numel() // 2
rnd_index = torch.randint(self.num_nodes, head_index.size(),
device=head_index.device)
rnd_index = torch.randint(
self.num_nodes, head_index.size(), device=head_index.device
)

head_index = head_index.clone()
head_index[:num_negatives] = rnd_index[:num_negatives]
Expand All @@ -163,6 +169,8 @@ def random_sample(
return head_index, rel_type, tail_index

def __repr__(self) -> str:
return (f'{self.__class__.__name__}({self.num_nodes}, '
f'num_relations={self.num_relations}, '
f'hidden_channels={self.hidden_channels})')
return (
f"{self.__class__.__name__}({self.num_nodes}, "
f"num_relations={self.num_relations}, "
f"hidden_channels={self.hidden_channels})"
)

0 comments on commit d144d0f

Please sign in to comment.