Heterogeneous dataloader with batches #6350
How do you create your sub-graphs? Keep in mind that the edge indices of each subgraph need to be mapped to new indices (from |
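To illustrate the re-indexing mentioned above, here is a minimal sketch in plain Python, kept free of dependencies on purpose. The node type names, IDs, and helper functions (`build_local_index`, `relabel_edges`) are hypothetical and not part of pytorch_geometric; in a real pipeline the resulting lists would presumably be converted to tensors before being stored in each mini-graph.

```python
def build_local_index(global_ids):
    """Return {global_id: local_index}, with local indices 0..n-1 in sorted id order."""
    return {g: i for i, g in enumerate(sorted(set(global_ids)))}


def relabel_edges(edges, src_map, dst_map):
    """Rewrite (src, dst) pairs from global ids to local indices.

    The result uses the [2, num_edges] layout: [source_row, destination_row].
    Source and destination node types each have their own index space.
    """
    src_row = [src_map[s] for s, _ in edges]
    dst_row = [dst_map[d] for _, d in edges]
    return [src_row, dst_row]


# Hypothetical mini-graph: two clients linked to three bookings (global ids).
client_ids = [4021, 17]
booking_ids = [8830, 12, 905]
client_to_booking = [(4021, 8830), (17, 12), (4021, 905)]

client_map = build_local_index(client_ids)    # {17: 0, 4021: 1}
booking_map = build_local_index(booking_ids)  # {12: 0, 905: 1, 8830: 2}

edge_index = relabel_edges(client_to_booking, client_map, booking_map)
print(edge_index)  # [[1, 0, 1], [2, 0, 1]]
```

With this layout, the feature rows of each node type would have to be ordered by the same local indices, so that row `i` of the client features describes the client mapped to `i`.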
I understand. But I have a follow-up question: once I have trained the model and made predictions on the re-indexed nodes, how do I recover the original indices to know which node each prediction belongs to? For example, one of the node types is booking, and another is client. I stripped the unique IDs of both from the node.x features, so I can only rely on their index to identify them. How can I then see the forecasts for all the bookings of a given client (a client who may well appear in several subgraphs)?
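One possible way to approach this, sketched in plain Python: keep, for every mini-graph, the original IDs in the same order as the local indices, and use that list to translate predictions back. The dictionary keys (`booking_ids`, `booking_client`), the IDs, and the prediction values below are all made up for illustration; how such ID lists are best attached to each mini-graph object is left open here.

```python
from collections import defaultdict

# Hypothetical mini-graphs. For each one we keep the original booking ids in
# local-index order, plus the original id of the client owning each booking.
mini_graphs = [
    {"booking_ids": [12, 905], "booking_client": [17, 4021]},
    {"booking_ids": [8830], "booking_client": [4021]},
]

# Hypothetical model outputs: one value per booking, in local-index order.
predictions = [[0.2, 0.7], [0.9]]


def predictions_by_client(mini_graphs, predictions):
    """Group predictions as {client_id: {booking_id: prediction}} using original ids."""
    out = defaultdict(dict)
    for graph, preds in zip(mini_graphs, predictions):
        for local_idx, pred in enumerate(preds):
            booking = graph["booking_ids"][local_idx]
            client = graph["booking_client"][local_idx]
            out[client][booking] = pred
    return dict(out)


by_client = predictions_by_client(mini_graphs, predictions)
print(by_client[4021])  # {905: 0.7, 8830: 0.9}
```

Because the lookup goes through the original IDs, a client that appears in several subgraphs ends up with all of its bookings collected under a single key.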
Hi,
I have been spending a lot of time with pytorch_geometric lately, and want to take this opportunity to thank you all for the work you have put into its development.
I am modeling a phenomenon with 5 types of nodes and 7 types of edges. I've started by creating a full graph, but the model training crashes, so I'm looking into batches.
I first looked into HGTLoader and NeighborLoader, but they picked nodes a bit too much at random. Since my data is rather organic (the phenomenon will basically always have the same sort of shape similar to the below), I am now creating a "mini-graph" for each phenomenon.
Once I have created all the mini-graphs, I group them in a DataLoader:
train_loader = DataLoader(train_data, batch_size=100, shuffle=True)
This ensures that each of my phenomena (with all its node types) is complete within a batch.
The problem I now have when launching the training loop is that the batches don't contain all the nodes of the graph, and therefore I get an error about the node indices being too large.
Indeed, since my whole graph contains 9K nodes of type 1, I might have indices higher than 99 in edge_index, even though the graph in the batch contains only 100 nodes of type 1.
How do you think I could get out of this predicament?
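As a small diagnostic for the situation described above, the following sketch checks whether the edge indices of one mini-graph fit the number of nodes it actually contains. It uses plain Python lists in the [2, num_edges] layout; the function name `check_edge_index` and the example numbers are hypothetical.

```python
def check_edge_index(edge_index, num_src, num_dst):
    """Return (column, src, dst) for every edge whose indices fall outside
    the valid ranges [0, num_src) and [0, num_dst)."""
    bad = []
    for col, (s, d) in enumerate(zip(edge_index[0], edge_index[1])):
        if not (0 <= s < num_src and 0 <= d < num_dst):
            bad.append((col, s, d))
    return bad


# Hypothetical mini-graph with 100 source nodes and 50 destination nodes.
# The first source index is still a global id from the full graph.
edge_index_global = [[8999, 3], [10, 49]]
print(check_edge_index(edge_index_global, 100, 50))  # [(0, 8999, 10)]
```

An empty result would suggest the mini-graph uses local indices consistently; any reported entry points at an edge that still carries an index from the full graph.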