Skip to content

Add trim_to_layer support & relevant examples #43

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jun 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 161 additions & 0 deletions examples/hetero/hierarchical_sage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
# Copyright 2022 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import torch
import torch.nn.functional as F
from tqdm import tqdm

import torch_geometric.transforms as T
from torch_geometric.datasets import OGB_MAG
from torch_geometric.nn import HeteroConv, Linear, SAGEConv
from torch_geometric.utils import trim_to_layer
import graphlearn_torch as glt


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

transforms = [T.ToUndirected(merge=True)]
dataset = OGB_MAG(root='../../data', preprocess='metapath2vec',
transform=T.Compose(transforms))
data = dataset[0].to(device, 'x', 'y')


class HierarchicalHeteroGraphSage(torch.nn.Module):
def __init__(self, edge_types, hidden_channels, out_channels, num_layers):
super().__init__()

self.convs = torch.nn.ModuleList()
for _ in range(num_layers):
conv = HeteroConv(
{
edge_type: SAGEConv((-1, -1), hidden_channels)
for edge_type in edge_types
}, aggr='sum')
self.convs.append(conv)

self.lin = Linear(hidden_channels, out_channels)

def forward(self, x_dict, edge_index_dict, num_sampled_edges_dict,
num_sampled_nodes_dict):

for i, conv in enumerate(self.convs):
x_dict, edge_index_dict, _ = trim_to_layer(
layer=i,
num_sampled_nodes_per_hop=num_sampled_nodes_dict,
num_sampled_edges_per_hop=num_sampled_edges_dict,
x=x_dict,
edge_index=edge_index_dict,
)

x_dict = conv(x_dict, edge_index_dict)
x_dict = {key: x.relu() for key, x in x_dict.items()}

return self.lin(x_dict['paper'])


model = HierarchicalHeteroGraphSage(
edge_types=data.edge_types,
hidden_channels=64,
out_channels=dataset.num_classes,
num_layers=2,
).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
# init graphlearn_torch Dataset.
edge_dict, feature_dict = {}, {}
for etype in data.edge_types:
edge_dict[etype] = data[etype]['edge_index']
for ntype in data.node_types:
feature_dict[ntype] = data[ntype].x.clone(memory_format=torch.contiguous_format)

glt_dataset = glt.data.Dataset()
glt_dataset.init_graph(
edge_index=edge_dict,
graph_mode='ZERO_COPY'
)
glt_dataset.init_node_features(
node_feature_data=feature_dict,
split_ratio=0.2,
device_group_list=[glt.data.DeviceGroup(0, [0])]
)
glt_dataset.init_node_labels(node_label_data={'paper': data['paper'].y})

train_idx = data['paper'].train_mask.nonzero(as_tuple=False).view(-1)
train_loader = glt.loader.NeighborLoader(glt_dataset,
[10] * 2,
('paper', train_idx),
batch_size=1024,
shuffle=True,
device=device)
val_idx = data['paper'].val_mask.nonzero(as_tuple=False).view(-1)
val_loader = glt.loader.NeighborLoader(glt_dataset,
[10] * 2,
('paper', val_idx),
batch_size=1024,
device=device)


def train():
model.train()

total_examples = total_loss = 0
for batch in tqdm(train_loader):
batch = batch.to(device)
optimizer.zero_grad()

out = model(
batch.x_dict,
batch.edge_index_dict,
num_sampled_nodes_dict=batch.num_sampled_nodes,
num_sampled_edges_dict=batch.num_sampled_edges,
)

batch_size = batch['paper'].batch_size
loss = F.cross_entropy(out[:batch_size], batch['paper'].y[:batch_size])
loss.backward()
optimizer.step()

total_examples += batch_size
total_loss += float(loss) * batch_size

return total_loss / total_examples


@torch.no_grad()
def test(loader):
model.eval()

total_examples = total_correct = 0
for batch in tqdm(loader):
batch = batch.to(device)
out = model(
batch.x_dict,
batch.edge_index_dict,
num_sampled_nodes_dict=batch.num_sampled_nodes,
num_sampled_edges_dict=batch.num_sampled_edges,
)

batch_size = batch['paper'].batch_size
pred = out[:batch_size].argmax(dim=-1)
total_examples += batch_size
total_correct += int((pred == batch['paper'].y[:batch_size]).sum())

return total_correct / total_examples


for epoch in range(1, 6):
loss = train()
val_acc = test(val_loader)
print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Val: {val_acc:.4f}')
109 changes: 109 additions & 0 deletions examples/train_sage_prod_with_trim.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
# Copyright 2022 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import os
import time
import torch

import numpy as np
import os.path as osp
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn.functional as F

from numpy import genfromtxt

from torch_geometric.nn import GraphSAGE
from ogb.nodeproppred import PygNodePropPredDataset
from tqdm import tqdm

import graphlearn_torch as glt


def run(rank, glt_ds, train_idx,
num_features, num_classes, trimmed):

train_loader = glt.loader.NeighborLoader(glt_ds,
[10, 10, 10],
train_idx,
batch_size=1024,
shuffle=True,
device=torch.device(rank))
print(f'Rank {rank} build graphlearn_torch NeighborLoader Done.')
model = GraphSAGE(
in_channels=num_features,
hidden_channels=256,
num_layers=3,
out_channels=num_classes,
).to(rank)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

for epoch in range(1, 10):
model.train()
start = time.time()
total_examples = total_loss = 0
for batch in tqdm(train_loader):
optimizer.zero_grad()
if trimmed:
out = model(
batch.x, batch.edge_index,
num_sampled_nodes_per_hop=batch.num_sampled_nodes,
num_sampled_edges_per_hop=batch.num_sampled_edges,
)[:batch.batch_size].log_softmax(dim=-1)
else:
out = model(
batch.x, batch.edge_index
)[:batch.batch_size].log_softmax(dim=-1)
loss = F.nll_loss(out, batch.y[:batch.batch_size])
loss.backward()
optimizer.step()
total_examples += batch.batch_size
total_loss += float(loss) * batch.batch_size
end = time.time()

print(f'Epoch: {epoch:03d}, Loss: {(total_loss / total_examples):.4f},',
f'Epoch Time: {end - start}')



if __name__ == '__main__':
world_size = torch.cuda.device_count()
start = time.time()
root = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'products')
dataset = PygNodePropPredDataset('ogbn-products', root)
split_idx = dataset.get_idx_split()
data = dataset[0]
train_idx = split_idx['train']
print(f'Load data cost {time.time()-start} s.')

start = time.time()
print('Build graphlearn_torch dataset...')
glt_dataset = glt.data.Dataset()
glt_dataset.init_graph(
edge_index=data.edge_index,
graph_mode='CUDA',
directed=False
)
glt_dataset.init_node_features(
node_feature_data=data.x,
sort_func=glt.data.sort_by_in_degree,
split_ratio=1,
device_group_list=[glt.data.DeviceGroup(0, [0])],
)
glt_dataset.init_node_labels(node_label_data=data.y)
print(f'Build graphlearn_torch csr_topo and feature cost {time.time() - start} s.')

run(0, glt_dataset, train_idx, 100, 47, True)
12 changes: 12 additions & 0 deletions graphlearn_torch/python/distributed/dist_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,7 @@ def _collate_fn(
if is_hetero:
node_dict, row_dict, col_dict, edge_dict = {}, {}, {}, {}
nfeat_dict, efeat_dict = {}, {}
num_sampled_nodes_dict, num_sampled_edges_dict = {}, {}

for ntype in self._node_types:
ids_key = f'{as_str(ntype)}.ids'
Expand All @@ -313,6 +314,9 @@ def _collate_fn(
nfeat_key = f'{as_str(ntype)}.nfeats'
if nfeat_key in msg:
nfeat_dict[ntype] = msg[nfeat_key].to(self.to_device)
num_sampled_nodes_key = f'{as_str(ntype)}.num_sampled_nodes'
if num_sampled_nodes_key in msg:
num_sampled_nodes_dict[ntype] = msg[num_sampled_nodes_key]

for etype_str, rev_etype in self._etype_str_to_rev.items():
rows_key = f'{etype_str}.rows'
Expand All @@ -324,6 +328,9 @@ def _collate_fn(
eids_key = f'{etype_str}.eids'
if eids_key in msg:
edge_dict[rev_etype] = msg[eids_key].to(self.to_device)
num_sampled_edges_key = f'{etype_str}.num_sampled_edges'
if num_sampled_edges_key in msg:
num_sampled_edges_dict[rev_etype] = msg[num_sampled_edges_key]
efeat_key = f'{etype_str}.efeats'
if efeat_key in msg:
efeat_dict[rev_etype] = msg[efeat_key].to(self.to_device)
Expand Down Expand Up @@ -351,6 +358,8 @@ def _collate_fn(
output = HeteroSamplerOutput(node_dict, row_dict, col_dict,
edge_dict if len(edge_dict) else None,
batch_dict,
num_sampled_nodes=num_sampled_nodes_dict,
num_sampled_edges=num_sampled_edges_dict,
edge_types=self._reversed_edge_types,
input_type=self._input_type,
device=self.to_device,
Expand All @@ -363,6 +372,8 @@ def _collate_fn(
rows = msg['rows'].to(self.to_device)
cols = msg['cols'].to(self.to_device)
eids = msg['eids'].to(self.to_device) if 'eids' in msg else None
num_sampled_nodes = msg['num_sampled_nodes'] if 'num_sampled_nodes' in msg else None
num_sampled_edges = msg['num_sampled_edges'] if 'num_sampled_edges' in msg else None

nfeats = msg['nfeats'].to(self.to_device) if 'nfeats' in msg else None
efeats = msg['efeats'].to(self.to_device) if 'efeats' in msg else None
Expand All @@ -377,6 +388,7 @@ def _collate_fn(

# The edge index should be reversed.
output = SamplerOutput(ids, cols, rows, eids, batch,
num_sampled_nodes, num_sampled_edges,
device=self.to_device, metadata=metadata)
res_data = to_data(output, batch_labels, nfeats, efeats)

Expand Down
Loading