How to properly perform inference on link prediction task for unseen graph? #8899
-
During validation and testing, are the edges in the prediction set also used for message passing? If so, that might be why you are overfitting on the train/val/test data.
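To make the question above concrete, here is a minimal, dependency-free sketch of keeping the edges you score (supervision edges) disjoint from the edges the GNN propagates over (message-passing edges). The function name `split_edges`, the `supervision_fraction` parameter, and the toy edge list are all hypothetical and not taken from the original post. In PyG, the `RandomLinkSplit` transform offers a `disjoint_train_ratio` argument for a similar purpose, though whether it fits this particular setup is an assumption.

```python
import random

def split_edges(edges, supervision_fraction=0.3, seed=0):
    """Split edges into disjoint message-passing and supervision sets.

    Hypothetical helper: the model would aggregate only over the
    message-passing edges and be trained/evaluated only on the
    supervision edges, so no scored edge is visible during aggregation.
    """
    rng = random.Random(seed)
    shuffled = list(edges)
    rng.shuffle(shuffled)
    n_sup = max(1, round(len(shuffled) * supervision_fraction))
    supervision = shuffled[:n_sup]
    message_passing = shuffled[n_sup:]
    return message_passing, supervision

# Toy type2 edge list (made-up values for illustration only).
type2_edges = [(0, 1), (1, 2), (2, 3), (3, 4), (4, 0),
               (0, 2), (1, 3), (2, 4), (3, 0), (4, 1)]

mp_edges, sup_edges = split_edges(type2_edges)

# No supervision edge may also be used for message passing.
leaked = set(mp_edges) & set(sup_edges)
print(f"message passing: {len(mp_edges)}, supervision: {len(sup_edges)}, leaked: {len(leaked)}")
```

If the edges being scored are also present in the graph during message passing, the model can learn to read them off rather than infer them, which would inflate AUC on every split that shares this setup.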
-
Hi again, I just realized that during training/validation, both edge types are included to update the node embeddings ( For inference, I need to predict the entire set of
-
Hi, I successfully trained a model on a set of heterogeneous graphs for the link prediction task. I used PyG with Lightning for training, and the AUC scores on both val and test are above 0.9. The data is split into train/val/test with 720/90/90 graphs. However, when I compute the pair-wise dot products at the inference stage for new, unseen graphs, the results do not seem useful: the scores for the true edges are very low after applying a `sigmoid` function and a threshold `T = 0.9`. Could you identify what I did wrong here?

My graphs have one node type and two edge types (`type1` and `type2`). During training, both edge types are available. In the inference stage, I want to predict the `edge_index` of `type2`.

Model definitions:

Example of the `training_step` (similar implementations for `validation_step` and `test_step`):

In the inference step, I first delete the `type2` edges in the new graphs, then compute the pair-wise dot products after feeding the node features into the HGNN part. I'm looping through each graph for the convenience of testing:

Getting results:
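One thing that may be worth checking for the thresholding issue described in this post: AUC only measures how well true edges are ranked above non-edges, so a high AUC does not guarantee that sigmoid scores are calibrated around any fixed cut-off such as 0.9. The following is a minimal, dependency-free sketch of that effect. It is not the poster's code; the embeddings, the edge set, and the helper names (`pairwise_scores`, `auc`) are hypothetical.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def pairwise_scores(emb):
    """Sigmoid of the dot product for every ordered node pair (i, j), i != j."""
    n = len(emb)
    return {
        (i, j): sigmoid(sum(a * b for a, b in zip(emb[i], emb[j])))
        for i in range(n) for j in range(n) if i != j
    }

def auc(pos, neg):
    """Fraction of (positive, negative) pairs ranked correctly; ties count half."""
    wins = sum((p > q) + 0.5 * (p == q) for p in pos for q in neg)
    return wins / (len(pos) * len(neg))

# Made-up node embeddings with small norms, hence small dot products.
emb = [[0.5, 0.0], [0.6, 0.1], [0.0, 0.4], [-0.1, 0.5]]
true_edges = {(0, 1), (1, 0), (2, 3), (3, 2)}

scores = pairwise_scores(emb)
pos = [s for e, s in scores.items() if e in true_edges]
neg = [s for e, s in scores.items() if e not in true_edges]

# Ranking quality is perfect here, yet no true edge clears T = 0.9.
ranking_auc = auc(pos, neg)
recall_fixed = sum(s >= 0.9 for s in pos) / len(pos)

# A threshold placed between the two score groups recovers every true edge.
# In practice this would be tuned on validation graphs, never on test graphs.
tuned_t = (min(pos) + max(neg)) / 2
recall_tuned = sum(s >= tuned_t for s in pos) / len(pos)

print(f"AUC={ranking_auc:.2f}, recall@0.9={recall_fixed:.2f}, recall@tuned={recall_tuned:.2f}")
```

If the real scores show the same pattern (true edges ranked first but all well below 0.9), the threshold is the likely culprit rather than the model; if true edges are not even ranked above non-edges on the unseen graphs, the mismatch between training and inference inputs is the more plausible cause.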