When I use triplet-mode negative sampling in `LinkNeighborLoader`, the loader does not return the edge labels corresponding to `dst_pos_index` and `dst_neg_index`. I need those edge labels.
Here is a minimal example:
```python
import torch
from torch_geometric.data import Data
from torch_geometric.loader import LinkNeighborLoader
from torch_geometric.sampler import NegativeSampling

# dummy data
x = torch.randn(6, 10)
edge_index = torch.tensor([[0, 1, 2, 3, 4], [1, 2, 3, 4, 5]], dtype=torch.long)
edge_label_index = torch.tensor([[2, 3], [0, 0]], dtype=torch.long)
edge_type = torch.tensor([0, 1, 0, 1, 0], dtype=torch.long)
edge_label = torch.tensor([1, 2], dtype=torch.long)  # one label per seed link

data = Data(
    x=x, edge_index=edge_index, edge_label_index=edge_label_index, edge_label=edge_label
)

loader = LinkNeighborLoader(
    data,
    edge_label_index=edge_label_index,
    edge_label=None,  # triplet mode does not accept edge_label
    batch_size=1,
    shuffle=True,
    num_neighbors=[-1],
    neg_sampling=NegativeSampling(mode="triplet", amount=10),
)

batch = next(iter(loader))
batch.edge_label  # all edge labels, not just the ones sampled by the loader
```
I expected the batch to contain the following attributes:

- `batch.dst_pos_label`: size `len(dst_pos_index)`; each value is the `edge_label` of the corresponding positive link in `dst_pos_index`, e.g. 1 or 2.
- `batch.dst_neg_label`: size `len(dst_neg_index)`; each value is the `edge_label` of the positive link that the corresponding negative in `dst_neg_index` was sampled for, e.g. 1 or 2. This tells me which `edge_label` each negative sample belongs to.
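A minimal sketch of how these two attributes could be computed by hand. It rests on two assumptions that should be checked against the installed PyG version: that the mini-batch carries `input_id` (the positions of its seed links within the `edge_label_index` passed to the loader), and that in triplet mode `dst_neg_index` has shape `[batch_size, amount]`, one row of negatives per seed (or `[batch_size]` when `amount == 1`). Under those assumptions, the positive labels are `edge_label[batch.input_id]`, and each negative simply inherits the label of the seed link in its row. The helper name `attach_triplet_labels` is hypothetical, not part of PyG.

```python
from types import SimpleNamespace

import torch


def attach_triplet_labels(batch, edge_label):
    """Hypothetical helper: add dst_pos_label / dst_neg_label to a batch.

    Assumes batch.input_id indexes the seed links in edge_label_index and
    batch.dst_neg_index is [batch_size, amount] (or 1-D, seed-major order).
    """
    # Label of each positive (seed) link, looked up by its seed position.
    pos_label = edge_label[batch.input_id]
    batch.dst_pos_label = pos_label

    neg_index = batch.dst_neg_index
    if neg_index.dim() == 2:
        # One row of negatives per seed: broadcast the seed label across the row.
        batch.dst_neg_label = pos_label.unsqueeze(-1).expand_as(neg_index).clone()
    else:
        # 1-D layout: assume negatives are grouped per seed, seed by seed.
        per_seed = neg_index.numel() // pos_label.numel()
        batch.dst_neg_label = pos_label.repeat_interleave(per_seed)
    return batch


# SimpleNamespace stands in for a real mini-batch; it only mimics the
# attributes the helper reads.
demo = SimpleNamespace(
    input_id=torch.tensor([1, 0]),
    dst_neg_index=torch.tensor([[0, 3, 4], [5, 1, 2]]),
)
attach_triplet_labels(demo, torch.tensor([1, 2]))
print(demo.dst_pos_label)  # tensor([2, 1])
print(demo.dst_neg_label)
```

With the loader from the example, this would be called as `attach_triplet_labels(batch, edge_label)` right after `batch = next(iter(loader))`. Note that triplet-mode negatives are random destination nodes, so a "label" for a negative only means the label of the seed link it is paired with.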
How can I achieve this?
Thank you