From d144d0f62e12c57f2008488651ab4ff45e557a31 Mon Sep 17 00:00:00 2001 From: Adrian Hinrichs Date: Mon, 27 Jan 2025 08:02:02 +0100 Subject: [PATCH] Fix division by zero in Hits@k metric by checking for length Fixes #9936 --- torch_geometric/nn/kge/base.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/torch_geometric/nn/kge/base.py b/torch_geometric/nn/kge/base.py index 653f1a1b04a5..b79dc40ba272 100644 --- a/torch_geometric/nn/kge/base.py +++ b/torch_geometric/nn/kge/base.py @@ -18,6 +18,7 @@ class KGEModel(torch.nn.Module): sparse (bool, optional): If set to :obj:`True`, gradients w.r.t. to the embedding matrices will be sparse. (default: :obj:`False`) """ + def __init__( self, num_nodes: int, @@ -123,15 +124,19 @@ def test( tail_indices = torch.arange(self.num_nodes, device=t.device) for ts in tail_indices.split(batch_size): scores.append(self(h.expand_as(ts), r.expand_as(ts), ts)) - rank = int((torch.cat(scores).argsort( - descending=True) == t).nonzero().view(-1)) + rank = int( + (torch.cat(scores).argsort(descending=True) == t).nonzero().view(-1) + ) mean_ranks.append(rank) reciprocal_ranks.append(1 / (rank + 1)) hits_at_k.append(rank < k) mean_rank = float(torch.tensor(mean_ranks, dtype=torch.float).mean()) mrr = float(torch.tensor(reciprocal_ranks, dtype=torch.float).mean()) - hits_at_k = int(torch.tensor(hits_at_k).sum()) / len(hits_at_k) + if len(hits_at_k) == 0: + hits_at_k = 0.0 + else: + hits_at_k = int(torch.tensor(hits_at_k).sum()) / len(hits_at_k) return mean_rank, mrr, hits_at_k @@ -152,8 +157,9 @@ def random_sample( """ # Random sample either `head_index` or `tail_index` (but not both): num_negatives = head_index.numel() // 2 - rnd_index = torch.randint(self.num_nodes, head_index.size(), - device=head_index.device) + rnd_index = torch.randint( + self.num_nodes, head_index.size(), device=head_index.device + ) head_index = head_index.clone() head_index[:num_negatives] = rnd_index[:num_negatives] @@ -163,6 +169,8 @@ def random_sample( return head_index, rel_type, tail_index def __repr__(self) -> str: - return (f'{self.__class__.__name__}({self.num_nodes}, ' - f'num_relations={self.num_relations}, ' - f'hidden_channels={self.hidden_channels})') + return ( + f"{self.__class__.__name__}({self.num_nodes}, " + f"num_relations={self.num_relations}, " + f"hidden_channels={self.hidden_channels})" + )