From 3fc5a6221a97c659cc56b6943a2f224300c47fb4 Mon Sep 17 00:00:00 2001
From: huxleyhu
Date: Mon, 29 May 2023 08:50:13 +0000
Subject: [PATCH 1/3] add trim

---
 examples/train_sage_prod_with_trim.py | 113 ++++++++++++++++++
 graphlearn_torch/python/loader/transform.py | 7 ++
 graphlearn_torch/python/sampler/base.py | 4 +
 .../python/sampler/neighbor_sampler.py | 14 ++-
 graphlearn_torch/python/utils/common.py | 8 +-
 5 files changed, 144 insertions(+), 2 deletions(-)
 create mode 100644 examples/train_sage_prod_with_trim.py

diff --git a/examples/train_sage_prod_with_trim.py b/examples/train_sage_prod_with_trim.py
new file mode 100644
index 00000000..67885f98
--- /dev/null
+++ b/examples/train_sage_prod_with_trim.py
@@ -0,0 +1,113 @@
+# Copyright 2022 Alibaba Group Holding Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import os
+import time
+import torch
+
+import numpy as np
+import os.path as osp
+import torch.distributed as dist
+import torch.multiprocessing as mp
+import torch.nn.functional as F
+
+from numpy import genfromtxt
+
+from torch_geometric.nn import GraphSAGE
+from ogb.nodeproppred import PygNodePropPredDataset
+from tqdm import tqdm
+
+import graphlearn_torch as glt
+
+
+def run(rank, glt_ds, train_idx,
+        num_features, num_classes, trimmed):
+
+  train_loader = glt.loader.NeighborLoader(glt_ds,
+                                           [10, 10, 10],
+                                           train_idx,
+                                           batch_size=1024,
+                                           shuffle=True,
+                                           device=torch.device(rank))
+  print(f'Rank {rank} build graphlearn_torch NeighborLoader Done.')
+  model = GraphSAGE(
+    in_channels=num_features,
+    hidden_channels=256,
+    num_layers=3,
+    out_channels=num_classes,
+  ).to(rank)
+
+  optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
+
+  for name, params in model.named_parameters():
+    print(name, ":", params.size())
+
+  for epoch in range(1, 10):
+    model.train()
+    start = time.time()
+    total_examples = total_loss = 0
+    for batch in tqdm(train_loader):
+      # print(batch.num_sampled_nodes)
+      optimizer.zero_grad()
+      if trimmed:
+        out = model(
+          batch.x, batch.edge_index,
+          num_sampled_nodes_per_hop=batch.num_sampled_nodes,
+          num_sampled_edges_per_hop=batch.num_sampled_edges,
+        )[:batch.batch_size].log_softmax(dim=-1)
+      else:
+        out = model(
+          batch.x, batch.edge_index
+        )[:batch.batch_size].log_softmax(dim=-1)
+      loss = F.nll_loss(out, batch.y[:batch.batch_size])
+      loss.backward()
+      optimizer.step()
+      total_examples += batch.batch_size
+      total_loss += float(loss) * batch.batch_size
+    end = time.time()
+
+    print(f'Epoch: {epoch:03d}, Loss: {(total_loss / total_examples):.4f},',
+          f'Epoch Time: {end - start}')
+
+
+
+if __name__ == '__main__':
+  world_size = torch.cuda.device_count()
+  start = time.time()
+  root = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'products')
+  dataset = PygNodePropPredDataset('ogbn-products', root)
+  split_idx = dataset.get_idx_split()
+  data = dataset[0]
+  train_idx = split_idx['train']
+  print(f'Load data cost {time.time()-start} s.')
+
+  start = time.time()
+  print('Build graphlearn_torch dataset...')
+  glt_dataset = glt.data.Dataset()
+  glt_dataset.init_graph(
+    edge_index=data.edge_index,
+    graph_mode='CUDA',
+    directed=False
+  )
+  glt_dataset.init_node_features(
+    node_feature_data=data.x,
+    sort_func=glt.data.sort_by_in_degree,
+    split_ratio=1,
+    device_group_list=[glt.data.DeviceGroup(0, [0])],
+  )
+  glt_dataset.init_node_labels(node_label_data=data.y)
+  print(f'Build graphlearn_torch csr_topo and feature cost {time.time() - start} s.')
+
+  run(0, glt_dataset, train_idx, 100, 47, True)
\ No newline at end of file
diff --git a/graphlearn_torch/python/loader/transform.py b/graphlearn_torch/python/loader/transform.py
index 01ecebbc..76961ba2 100644
--- a/graphlearn_torch/python/loader/transform.py
+++ b/graphlearn_torch/python/loader/transform.py
@@ -38,6 +38,9 @@ def to_data(
   data.batch = sampler_out.batch
   data.batch_size = sampler_out.batch.numel() if data.batch is not None else 0

+  data.num_sampled_nodes = sampler_out.num_sampled_nodes
+  data.num_sampled_edges = sampler_out.num_sampled_edges
+
   # update meta data
   if isinstance(sampler_out.metadata, dict):
     for k, v in sampler_out.metadata.items():
@@ -83,6 +86,10 @@ def to_hetero_data(
     if batch_label_dict is not None:
       data[k].y = batch_label_dict.get(k, None)

+  # update num_sampled_nodes & num_sampled_edges
+  data.num_sampled_nodes = hetero_sampler_out.num_sampled_nodes
+  data.num_sampled_edges = hetero_sampler_out.num_sampled_edges
+
   # update meta data
   input_type = hetero_sampler_out.input_type
   if isinstance(hetero_sampler_out.metadata, dict):
diff --git a/graphlearn_torch/python/sampler/base.py b/graphlearn_torch/python/sampler/base.py
index c4ebcd5a..980cbeb3 100644
--- a/graphlearn_torch/python/sampler/base.py
+++ b/graphlearn_torch/python/sampler/base.py
@@ -235,6 +235,8 @@ class SamplerOutput(CastMixin):
   col: torch.Tensor
   edge: Optional[torch.Tensor] = None
   batch: Optional[torch.Tensor] = None
+  num_sampled_nodes: Optional[List[int]] = None
+  num_sampled_edges: Optional[List[int]] = None
   device: Optional[torch.device] = None
   metadata: Optional[Any] = None

@@ -282,6 +284,8 @@ class HeteroSamplerOutput(CastMixin):
   col: Dict[EdgeType, torch.Tensor]
   edge: Optional[Dict[EdgeType, torch.Tensor]] = None
   batch: Optional[Dict[NodeType, torch.Tensor]] = None
+  num_sampled_nodes: Optional[Dict[NodeType, List[int]]] = None
+  num_sampled_edges: Optional[Dict[EdgeType, List[int]]] = None
   edge_types: Optional[List[EdgeType]] = None
   input_type: Optional[Union[NodeType, EdgeType]] = None
   device: Optional[torch.device] = None
diff --git a/graphlearn_torch/python/sampler/neighbor_sampler.py b/graphlearn_torch/python/sampler/neighbor_sampler.py
index 8348606f..93fede9a 100644
--- a/graphlearn_torch/python/sampler/neighbor_sampler.py
+++ b/graphlearn_torch/python/sampler/neighbor_sampler.py
@@ -23,7 +23,7 @@
 from ..typing import NodeType, EdgeType, NumNeighbors, reverse_edge_type
 from ..utils import (
   merge_dict, merge_hetero_sampler_output, format_hetero_sampler_output,
-  id2idx
+  id2idx, count_dict
 )

@@ -165,9 +165,11 @@ def _sample_from_nodes(
       [dst_index, src_index].
""" out_nodes, out_rows, out_cols, out_edges = [], [], [], [] + num_sampled_nodes, num_sampled_edges = [], [] inducer = self.get_inducer(input_seeds.numel()) srcs = inducer.init_node(input_seeds) batch = srcs + num_sampled_nodes.append(input_seeds.numel()) out_nodes.append(srcs) for req_num in self.num_neighbors: out_nbrs = self.sample_one_hop(srcs, req_num) @@ -178,6 +180,8 @@ def _sample_from_nodes( out_cols.append(cols) if out_nbrs.edge is not None: out_edges.append(out_nbrs.edge) + num_sampled_nodes.append(nodes.size(0)) + num_sampled_edges.append(cols.size(0)) srcs = nodes return SamplerOutput( @@ -186,6 +190,8 @@ def _sample_from_nodes( col=torch.cat(out_rows), edge=(torch.cat(out_edges) if out_edges else None), batch=batch, + num_sampled_nodes=num_sampled_nodes, + num_sampled_edges=num_sampled_edges, device=self.device ) @@ -209,7 +215,9 @@ def _hetero_sample_from_nodes( src_dict = inducer.init_node(input_seeds_dict) batch = src_dict out_nodes, out_rows, out_cols, out_edges = {}, {}, {}, {} + num_sampled_nodes, num_sampled_edges = {}, {} merge_dict(src_dict, out_nodes) + count_dict(src_dict, num_sampled_nodes, 1) for i in range(self.num_hops): nbr_dict, edge_dict = {}, {} for etype in self.edge_types: @@ -225,6 +233,8 @@ def _hetero_sample_from_nodes( merge_dict(rows_dict, out_rows) merge_dict(cols_dict, out_cols) merge_dict(edge_dict, out_edges) + count_dict(nodes_dict, num_sampled_nodes, i + 2) + count_dict(cols_dict, num_sampled_edges, i + 1) src_dict = nodes_dict for etype, rows in out_rows.items(): @@ -248,6 +258,8 @@ def _hetero_sample_from_nodes( col=res_cols, edge=(res_edges if len(res_edges) else None), batch=batch, + num_sampled_nodes=num_sampled_nodes, + num_sampled_edges=num_sampled_edges, edge_types=self.edge_types, device=self.device ) diff --git a/graphlearn_torch/python/utils/common.py b/graphlearn_torch/python/utils/common.py index e67253fa..fc0b40e0 100644 --- a/graphlearn_torch/python/utils/common.py +++ b/graphlearn_torch/python/utils/common.py @@ -15,7 +15,7 @@ import os import socket -from typing import Any, Dict +from typing import Any, Dict, Callable, Optional from ..typing import reverse_edge_type from .tensor import id2idx @@ -32,6 +32,12 @@ def merge_dict(in_dict: Dict[Any, Any], out_dict: Dict[Any, Any]): vals.append(v) out_dict[k] = vals +def count_dict(in_dict: Dict[Any, Any], out_dict: Dict[Any, Any], target_len): + for k, v in in_dict.items(): + vals = out_dict.get(k, []) + vals += [0] * (target_len - len(vals) - 1) + vals.append(len(v)) + out_dict[k] = vals def get_free_port(host: str = 'localhost') -> int: s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) From 0f024a29c049db83061b56a96ed135a5f56eda8c Mon Sep 17 00:00:00 2001 From: huxleyhu Date: Wed, 31 May 2023 08:55:11 +0000 Subject: [PATCH 2/3] add dist&hetero stats --- examples/hetero/hierarchical_sage.py | 149 ++++++++++++++++++ .../python/distributed/dist_loader.py | 14 ++ .../distributed/dist_neighbor_sampler.py | 19 +++ graphlearn_torch/python/loader/transform.py | 12 ++ .../python/sampler/neighbor_sampler.py | 3 +- 5 files changed, 196 insertions(+), 1 deletion(-) create mode 100644 examples/hetero/hierarchical_sage.py diff --git a/examples/hetero/hierarchical_sage.py b/examples/hetero/hierarchical_sage.py new file mode 100644 index 00000000..a91fbbc8 --- /dev/null +++ b/examples/hetero/hierarchical_sage.py @@ -0,0 +1,149 @@ +import argparse + +import torch +import torch.nn.functional as F +from tqdm import tqdm + +import torch_geometric.transforms as T +from torch_geometric.datasets 
+from torch_geometric.nn import HeteroConv, Linear, SAGEConv
+from torch_geometric.utils import trim_to_layer
+import graphlearn_torch as glt
+
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+
+transforms = [T.ToUndirected(merge=True)]
+dataset = OGB_MAG(root='../../data', preprocess='metapath2vec',
+                  transform=T.Compose(transforms))
+data = dataset[0].to(device, 'x', 'y')
+
+
+class HierarchicalHeteroGraphSage(torch.nn.Module):
+    def __init__(self, edge_types, hidden_channels, out_channels, num_layers):
+        super().__init__()
+
+        self.convs = torch.nn.ModuleList()
+        for _ in range(num_layers):
+            conv = HeteroConv(
+                {
+                    edge_type: SAGEConv((-1, -1), hidden_channels)
+                    for edge_type in edge_types
+                }, aggr='sum')
+            self.convs.append(conv)
+
+        self.lin = Linear(hidden_channels, out_channels)
+
+    def forward(self, x_dict, edge_index_dict, num_sampled_edges_dict,
+                num_sampled_nodes_dict):
+
+        for i, conv in enumerate(self.convs):
+            x_dict, edge_index_dict, _ = trim_to_layer(
+                layer=i,
+                num_sampled_nodes_per_hop=num_sampled_nodes_dict,
+                num_sampled_edges_per_hop=num_sampled_edges_dict,
+                x=x_dict,
+                edge_index=edge_index_dict,
+            )
+
+            x_dict = conv(x_dict, edge_index_dict)
+            x_dict = {key: x.relu() for key, x in x_dict.items()}
+
+        return self.lin(x_dict['paper'])
+
+
+model = HierarchicalHeteroGraphSage(
+    edge_types=data.edge_types,
+    hidden_channels=64,
+    out_channels=dataset.num_classes,
+    num_layers=2,
+).to(device)
+
+optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
+ # init graphlearn_torch Dataset.
+edge_dict, feature_dict = {}, {}
+for etype in data.edge_types:
+    edge_dict[etype] = data[etype]['edge_index']
+for ntype in data.node_types:
+    feature_dict[ntype] = data[ntype].x.clone(memory_format=torch.contiguous_format)
+
+glt_dataset = glt.data.Dataset()
+glt_dataset.init_graph(
+    edge_index=edge_dict,
+    graph_mode='ZERO_COPY'
+)
+glt_dataset.init_node_features(
+    node_feature_data=feature_dict,
+    split_ratio=0.2,
+    device_group_list=[glt.data.DeviceGroup(0, [0])]
+)
+glt_dataset.init_node_labels(node_label_data={'paper': data['paper'].y})
+
+train_idx = data['paper'].train_mask.nonzero(as_tuple=False).view(-1)
+train_loader = glt.loader.NeighborLoader(glt_dataset,
+                                         [10] * 2,
+                                         ('paper', train_idx),
+                                         batch_size=1024,
+                                         shuffle=True,
+                                         device=device)
+val_idx = data['paper'].val_mask.nonzero(as_tuple=False).view(-1)
+val_loader = glt.loader.NeighborLoader(glt_dataset,
+                                       [10] * 2,
+                                       ('paper', val_idx),
+                                       batch_size=1024,
+                                       device=device)
+
+
+def train():
+    model.train()
+
+    total_examples = total_loss = 0
+    for batch in tqdm(train_loader):
+        batch = batch.to(device)
+        optimizer.zero_grad()
+
+        out = model(
+            batch.x_dict,
+            batch.edge_index_dict,
+            num_sampled_nodes_dict=batch.num_sampled_nodes,
+            num_sampled_edges_dict=batch.num_sampled_edges,
+        )
+
+        batch_size = batch['paper'].batch_size
+        loss = F.cross_entropy(out[:batch_size], batch['paper'].y[:batch_size])
+        loss.backward()
+        optimizer.step()
+
+        total_examples += batch_size
+        total_loss += float(loss) * batch_size
+
+    return total_loss / total_examples
+
+
+@torch.no_grad()
+def test(loader):
+    model.eval()
+
+    total_examples = total_correct = 0
+    for batch in tqdm(loader):
+        batch = batch.to(device)
+        out = model(
+            batch.x_dict,
+            batch.edge_index_dict,
+            num_sampled_nodes_dict=batch.num_sampled_nodes,
+            num_sampled_edges_dict=batch.num_sampled_edges,
+        )
+
+        batch_size = batch['paper'].batch_size
+        pred = out[:batch_size].argmax(dim=-1)
+        total_examples += batch_size
+        total_correct += int((pred == batch['paper'].y[:batch_size]).sum())
+
+    return total_correct / total_examples
+
+
+for epoch in range(1, 6):
+    loss = train()
+    val_acc = test(val_loader)
+    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Val: {val_acc:.4f}')
diff --git a/graphlearn_torch/python/distributed/dist_loader.py b/graphlearn_torch/python/distributed/dist_loader.py
index e79176e5..215bcdb8 100644
--- a/graphlearn_torch/python/distributed/dist_loader.py
+++ b/graphlearn_torch/python/distributed/dist_loader.py
@@ -305,6 +305,7 @@ def _collate_fn(
     if is_hetero:
       node_dict, row_dict, col_dict, edge_dict = {}, {}, {}, {}
       nfeat_dict, efeat_dict = {}, {}
+      num_sampled_nodes_dict, num_sampled_edges_dict = {}, {}

       for ntype in self._node_types:
         ids_key = f'{as_str(ntype)}.ids'
@@ -313,6 +314,9 @@ def _collate_fn(
         nfeat_key = f'{as_str(ntype)}.nfeats'
         if nfeat_key in msg:
           nfeat_dict[ntype] = msg[nfeat_key].to(self.to_device)
+        num_sampled_nodes_key = f'{as_str(ntype)}.num_sampled_nodes'
+        if num_sampled_nodes_key in msg:
+          num_sampled_nodes_dict[ntype] = msg[num_sampled_nodes_key].to(self.to_device)

       for etype_str, rev_etype in self._etype_str_to_rev.items():
         rows_key = f'{etype_str}.rows'
@@ -324,6 +328,9 @@ def _collate_fn(
         eids_key = f'{etype_str}.eids'
         if eids_key in msg:
           edge_dict[rev_etype] = msg[eids_key].to(self.to_device)
+        num_sampled_edges_key = f'{etype_str}.num_sampled_edges'
+        if num_sampled_edges_key in msg:
+          num_sampled_edges_dict[rev_etype] = msg[num_sampled_edges_key].to(self.to_device)
         efeat_key = f'{etype_str}.efeats'
         if efeat_key in msg:
           efeat_dict[rev_etype] = msg[efeat_key].to(self.to_device)
@@ -351,6 +358,8 @@ def _collate_fn(
       output = HeteroSamplerOutput(node_dict, row_dict, col_dict,
                                    edge_dict if len(edge_dict) else None,
                                    batch_dict,
+                                   num_sampled_nodes=num_sampled_nodes_dict,
+                                   num_sampled_edges=num_sampled_edges_dict,
                                    edge_types=self._reversed_edge_types,
                                    input_type=self._input_type,
                                    device=self.to_device,
@@ -363,6 +372,10 @@ def _collate_fn(
       rows = msg['rows'].to(self.to_device)
       cols = msg['cols'].to(self.to_device)
       eids = msg['eids'].to(self.to_device) if 'eids' in msg else None
+      num_sampled_nodes = msg['num_sampled_nodes'].to(self.to_device) \
+        if 'num_sampled_nodes' in msg else None
+      num_sampled_edges = msg['num_sampled_edges'].to(self.to_device) \
+        if 'num_sampled_edges' in msg else None
       nfeats = msg['nfeats'].to(self.to_device) if 'nfeats' in msg else None
       efeats = msg['efeats'].to(self.to_device) if 'efeats' in msg else None
@@ -377,6 +390,7 @@ def _collate_fn(
       # The edge index should be reversed.
       output = SamplerOutput(ids, cols, rows, eids, batch,
+                             num_sampled_nodes, num_sampled_edges,
                              device=self.to_device,
                              metadata=metadata)
       res_data = to_data(output, batch_labels, nfeats, efeats)
diff --git a/graphlearn_torch/python/distributed/dist_neighbor_sampler.py b/graphlearn_torch/python/distributed/dist_neighbor_sampler.py
index c11c16ae..c8f68447 100644
--- a/graphlearn_torch/python/distributed/dist_neighbor_sampler.py
+++ b/graphlearn_torch/python/distributed/dist_neighbor_sampler.py
@@ -31,6 +31,7 @@
 from ..utils import (
   get_available_device, ensure_device, merge_dict, id2idx,
   merge_hetero_sampler_output, format_hetero_sampler_output,
+  count_dict
 )
 from .dist_dataset import DistDataset

@@ -265,7 +266,9 @@ async def _sample_from_nodes(
       assert input_type is not None
       src_dict = inducer.init_node({input_type: input_seeds})
       out_nodes, out_rows, out_cols, out_edges = {}, {}, {}, {}
+      num_sampled_nodes, num_sampled_edges = {}, {}
       merge_dict(src_dict, out_nodes)
+      count_dict(src_dict, num_sampled_nodes)

       for i in range(self.num_hops):
         task_dict, nbr_dict, edge_dict = {}, {}, {}
@@ -285,6 +288,8 @@ async def _sample_from_nodes(
           merge_dict(rows_dict, out_rows)
           merge_dict(cols_dict, out_cols)
           merge_dict(edge_dict, out_edges)
+          count_dict(nodes_dict, num_sampled_nodes)
+          count_dict(cols_dict, num_sampled_edges)
           src_dict = nodes_dict

       sample_output = HeteroSamplerOutput(
@@ -295,13 +300,17 @@ async def _sample_from_nodes(
           {etype: torch.cat(eids) for etype, eids in out_edges.items()}
           if self.with_edge else None
         ),
+        num_sampled_nodes=num_sampled_nodes,
+        num_sampled_edges=num_sampled_edges,
         input_type=input_type,
         metadata={}
       )
     else:
       srcs = inducer.init_node(input_seeds)
       out_nodes, out_edges = [], []
+      num_sampled_nodes, num_sampled_edges = [], []
       out_nodes.append(srcs)
+      num_sampled_nodes.append(srcs.size(0))
       # Sample subgraph.
       for req_num in self.num_neighbors:
         output: NeighborOutput = await self._sample_one_hop(srcs, req_num, None)
@@ -309,6 +318,8 @@ async def _sample_from_nodes(
           inducer.induce_next(srcs, output.nbr, output.nbr_num)
         out_nodes.append(nodes)
         out_edges.append((rows, cols, output.edge))
+        num_sampled_nodes.append(nodes.size(0))
+        num_sampled_edges.append(cols.size(0))
         srcs = nodes

       sample_output = SamplerOutput(
@@ -316,6 +327,8 @@ async def _sample_from_nodes(
         row=torch.cat([e[0] for e in out_edges]),
         col=torch.cat([e[1] for e in out_edges]),
         edge=(torch.cat([e[2] for e in out_edges]) if self.with_edge else None),
+        num_sampled_nodes=num_sampled_nodes,
+        num_sampled_edges=num_sampled_edges,
         metadata={}
       )
       # Reclaim inducer into pool.
@@ -615,12 +628,16 @@ async def _colloate_fn(
     if is_hetero:
       for ntype, nodes in output.node.items():
         result_map[f'{as_str(ntype)}.ids'] = nodes
+        result_map[f'{as_str(ntype)}.num_sampled_nodes'] = \
+          output.num_sampled_nodes[ntype]
       for etype, rows in output.row.items():
         etype_str = as_str(etype)
         result_map[f'{etype_str}.rows'] = rows
         result_map[f'{etype_str}.cols'] = output.col[etype]
         if self.with_edge:
           result_map[f'{etype_str}.eids'] = output.edge[etype]
+        result_map[f'{etype_str}.num_sampled_edges'] = \
+          output.num_sampled_edges[etype]
       # Collect node labels of input node type.
       input_type = output.input_type
       assert input_type is not None
@@ -652,6 +669,8 @@ async def _colloate_fn(
       result_map['ids'] = output.node
       result_map['rows'] = output.row
       result_map['cols'] = output.col
+      result_map['num_sampled_nodes'] = output.num_sampled_nodes
+      result_map['num_sampled_edges'] = output.num_sampled_edges
       if self.with_edge:
         result_map['eids'] = output.edge
       # Collect node labels.
diff --git a/graphlearn_torch/python/loader/transform.py b/graphlearn_torch/python/loader/transform.py
index 76961ba2..a7b0f31b 100644
--- a/graphlearn_torch/python/loader/transform.py
+++ b/graphlearn_torch/python/loader/transform.py
@@ -65,6 +65,8 @@ def to_hetero_data(
 ) -> HeteroData:
   data = HeteroData(**kwargs)
   edge_index_dict = hetero_sampler_out.get_edge_index()
+  num_hops = max(map(
+    lambda x: len(x), list(hetero_sampler_out.num_sampled_edges.values())))
   # edges
   for k, v in edge_index_dict.items():
     data[k].edge_index = v
@@ -72,12 +74,22 @@ def to_hetero_data(
       data[k].edge = hetero_sampler_out.edge.get(k, None)
     if edge_feat_dict is not None:
       data[k].edge_attr = edge_feat_dict.get(k, None)
+    if k not in hetero_sampler_out.num_sampled_edges:
+      hetero_sampler_out.num_sampled_edges[k] = [0] * num_hops
+    else:
+      hetero_sampler_out.num_sampled_edges[k] += \
+        [0] * (num_hops - len(hetero_sampler_out.num_sampled_edges[k]))

   # nodes
   for k, v in hetero_sampler_out.node.items():
     data[k].node = v
     if node_feat_dict is not None:
       data[k].x = node_feat_dict.get(k, None)
+    if k not in hetero_sampler_out.num_sampled_nodes:
+      hetero_sampler_out.num_sampled_nodes[k] = [0] * (num_hops + 1)
+    else:
+      hetero_sampler_out.num_sampled_nodes[k] += \
+        [0] * (num_hops + 1 - len(hetero_sampler_out.num_sampled_nodes[k]))

   # seed nodes
   for k, v in hetero_sampler_out.batch.items():
diff --git a/graphlearn_torch/python/sampler/neighbor_sampler.py b/graphlearn_torch/python/sampler/neighbor_sampler.py
index 93fede9a..1021cad8 100644
--- a/graphlearn_torch/python/sampler/neighbor_sampler.py
+++ b/graphlearn_torch/python/sampler/neighbor_sampler.py
@@ -259,7 +259,8 @@ def _hetero_sample_from_nodes(
       edge=(res_edges if len(res_edges) else None),
       batch=batch,
       num_sampled_nodes=num_sampled_nodes,
-      num_sampled_edges=num_sampled_edges,
+      num_sampled_edges={
+        reverse_edge_type(k) : v for k, v in num_sampled_edges.items()},
       edge_types=self.edge_types,
       device=self.device
     )

From 51e31d7698246a78f7973a2454b910f1cbf19c6d Mon Sep 17 00:00:00 2001
From: huxleyhu
Date: Wed, 31 May 2023 14:12:30 +0000
Subject: [PATCH 3/3] add test for dist

---
 examples/hetero/hierarchical_sage.py | 16 +++++++++--
 examples/train_sage_prod_with_trim.py | 4 ---
 .../python/distributed/dist_loader.py | 10 +++----
 .../distributed/dist_neighbor_sampler.py | 27 ++++++++++++-------
 graphlearn_torch/python/loader/transform.py | 19 ++++++++-----
 graphlearn_torch/python/sampler/base.py | 8 +++---
 .../python/sampler/neighbor_sampler.py | 6 +++--
 test/python/test_dist_neighbor_loader.py | 17 ++++++++++++
 8 files changed, 73 insertions(+), 34 deletions(-)

diff --git a/examples/hetero/hierarchical_sage.py b/examples/hetero/hierarchical_sage.py
index a91fbbc8..e5c5264d 100644
--- a/examples/hetero/hierarchical_sage.py
+++ b/examples/hetero/hierarchical_sage.py
@@ -1,4 +1,17 @@
-import argparse
+# Copyright 2022 Alibaba Group Holding Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================

 import torch
 import torch.nn.functional as F
@@ -13,7 +26,6 @@

 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

-
 transforms = [T.ToUndirected(merge=True)]
 dataset = OGB_MAG(root='../../data', preprocess='metapath2vec',
                   transform=T.Compose(transforms))
diff --git a/examples/train_sage_prod_with_trim.py b/examples/train_sage_prod_with_trim.py
index 67885f98..9855cc88 100644
--- a/examples/train_sage_prod_with_trim.py
+++ b/examples/train_sage_prod_with_trim.py
@@ -51,15 +51,11 @@ def run(rank, glt_ds, train_idx,

   optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

-  for name, params in model.named_parameters():
-    print(name, ":", params.size())
-
   for epoch in range(1, 10):
     model.train()
     start = time.time()
     total_examples = total_loss = 0
     for batch in tqdm(train_loader):
-      # print(batch.num_sampled_nodes)
       optimizer.zero_grad()
       if trimmed:
         out = model(
diff --git a/graphlearn_torch/python/distributed/dist_loader.py b/graphlearn_torch/python/distributed/dist_loader.py
index 215bcdb8..b4f7ee96 100644
--- a/graphlearn_torch/python/distributed/dist_loader.py
+++ b/graphlearn_torch/python/distributed/dist_loader.py
@@ -316,7 +316,7 @@ def _collate_fn(
           nfeat_dict[ntype] = msg[nfeat_key].to(self.to_device)
         num_sampled_nodes_key = f'{as_str(ntype)}.num_sampled_nodes'
         if num_sampled_nodes_key in msg:
-          num_sampled_nodes_dict[ntype] = msg[num_sampled_nodes_key].to(self.to_device)
+          num_sampled_nodes_dict[ntype] = msg[num_sampled_nodes_key]

       for etype_str, rev_etype in self._etype_str_to_rev.items():
         rows_key = f'{etype_str}.rows'
@@ -330,7 +330,7 @@ def _collate_fn(
           edge_dict[rev_etype] = msg[eids_key].to(self.to_device)
         num_sampled_edges_key = f'{etype_str}.num_sampled_edges'
         if num_sampled_edges_key in msg:
-          num_sampled_edges_dict[rev_etype] = msg[num_sampled_edges_key].to(self.to_device)
+          num_sampled_edges_dict[rev_etype] = msg[num_sampled_edges_key]
         efeat_key = f'{etype_str}.efeats'
         if efeat_key in msg:
           efeat_dict[rev_etype] = msg[efeat_key].to(self.to_device)
@@ -372,10 +372,8 @@ def _collate_fn(
       rows = msg['rows'].to(self.to_device)
       cols = msg['cols'].to(self.to_device)
       eids = msg['eids'].to(self.to_device) if 'eids' in msg else None
-      num_sampled_nodes = msg['num_sampled_nodes'].to(self.to_device) \
-        if 'num_sampled_nodes' in msg else None
-      num_sampled_edges = msg['num_sampled_edges'].to(self.to_device) \
-        if 'num_sampled_edges' in msg else None
+      num_sampled_nodes = msg['num_sampled_nodes'] if 'num_sampled_nodes' in msg else None
+      num_sampled_edges = msg['num_sampled_edges'] if 'num_sampled_edges' in msg else None
       nfeats = msg['nfeats'].to(self.to_device) if 'nfeats' in msg else None
       efeats = msg['efeats'].to(self.to_device) if 'efeats' in msg else None
diff --git a/graphlearn_torch/python/distributed/dist_neighbor_sampler.py b/graphlearn_torch/python/distributed/dist_neighbor_sampler.py
index c8f68447..3f451044 100644
--- a/graphlearn_torch/python/distributed/dist_neighbor_sampler.py
+++ b/graphlearn_torch/python/distributed/dist_neighbor_sampler.py
@@ -31,7 +31,7 @@
 from ..utils import (
   get_available_device, ensure_device, merge_dict, id2idx,
   merge_hetero_sampler_output, format_hetero_sampler_output,
-  count_dict
+  count_dict, convert_to_tensor
 )
 from .dist_dataset import DistDataset

@@ -268,7 +268,7 @@ async def _sample_from_nodes(
       out_nodes, out_rows, out_cols, out_edges = {}, {}, {}, {}
       num_sampled_nodes, num_sampled_edges = {}, {}
       merge_dict(src_dict, out_nodes)
-      count_dict(src_dict, num_sampled_nodes)
+      count_dict(src_dict, num_sampled_nodes, 1)

       for i in range(self.num_hops):
         task_dict, nbr_dict, edge_dict = {}, {}, {}
@@ -288,8 +288,8 @@ async def _sample_from_nodes(
           merge_dict(rows_dict, out_rows)
           merge_dict(cols_dict, out_cols)
           merge_dict(edge_dict, out_edges)
-          count_dict(nodes_dict, num_sampled_nodes)
-          count_dict(cols_dict, num_sampled_edges)
+          count_dict(nodes_dict, num_sampled_nodes, i + 2)
+          count_dict(cols_dict, num_sampled_edges, i + 1)
           src_dict = nodes_dict

       sample_output = HeteroSamplerOutput(
@@ -628,16 +628,20 @@ async def _colloate_fn(
     if is_hetero:
       for ntype, nodes in output.node.items():
         result_map[f'{as_str(ntype)}.ids'] = nodes
-        result_map[f'{as_str(ntype)}.num_sampled_nodes'] = \
-          output.num_sampled_nodes[ntype]
+        if output.num_sampled_nodes is not None:
+          if ntype in output.num_sampled_nodes:
+            result_map[f'{as_str(ntype)}.num_sampled_nodes'] = \
+              torch.tensor(output.num_sampled_nodes[ntype], device=self.device)
       for etype, rows in output.row.items():
         etype_str = as_str(etype)
         result_map[f'{etype_str}.rows'] = rows
         result_map[f'{etype_str}.cols'] = output.col[etype]
         if self.with_edge:
           result_map[f'{etype_str}.eids'] = output.edge[etype]
-        result_map[f'{etype_str}.num_sampled_edges'] = \
-          output.num_sampled_edges[etype]
+        if output.num_sampled_edges is not None:
+          if etype in output.num_sampled_edges:
+            result_map[f'{etype_str}.num_sampled_edges'] = \
+              torch.tensor(output.num_sampled_edges[etype], device=self.device)
       # Collect node labels of input node type.
       input_type = output.input_type
       assert input_type is not None
@@ -669,8 +673,11 @@ async def _colloate_fn(
       result_map['ids'] = output.node
       result_map['rows'] = output.row
       result_map['cols'] = output.col
-      result_map['num_sampled_nodes'] = output.num_sampled_nodes
-      result_map['num_sampled_edges'] = output.num_sampled_edges
+      if output.num_sampled_nodes is not None:
+        result_map['num_sampled_nodes'] = \
+          torch.tensor(output.num_sampled_nodes, device=self.device)
+        result_map['num_sampled_edges'] = \
+          torch.tensor(output.num_sampled_edges, device=self.device)
       if self.with_edge:
         result_map['eids'] = output.edge
       # Collect node labels.
diff --git a/graphlearn_torch/python/loader/transform.py b/graphlearn_torch/python/loader/transform.py
index a7b0f31b..fc4d2f56 100644
--- a/graphlearn_torch/python/loader/transform.py
+++ b/graphlearn_torch/python/loader/transform.py
@@ -16,6 +16,7 @@
 from typing import Dict, Optional

 import torch
+import torch.nn.functional as F
 from torch_geometric.data import Data, HeteroData

 from ..sampler import SamplerOutput, HeteroSamplerOutput
@@ -75,10 +76,13 @@ def to_hetero_data(
     if edge_feat_dict is not None:
       data[k].edge_attr = edge_feat_dict.get(k, None)
     if k not in hetero_sampler_out.num_sampled_edges:
-      hetero_sampler_out.num_sampled_edges[k] = [0] * num_hops
+      hetero_sampler_out.num_sampled_edges[k] = \
+        torch.tensor([0] * num_hops, device=data[k].edge_index.device)
     else:
-      hetero_sampler_out.num_sampled_edges[k] += \
-        [0] * (num_hops - len(hetero_sampler_out.num_sampled_edges[k]))
+      hetero_sampler_out.num_sampled_edges[k] = F.pad(
+        hetero_sampler_out.num_sampled_edges[k],
+        (0, num_hops - hetero_sampler_out.num_sampled_edges[k].size(0))
+      )

   # nodes
   for k, v in hetero_sampler_out.node.items():
@@ -86,10 +90,13 @@ def to_hetero_data(
     if node_feat_dict is not None:
       data[k].x = node_feat_dict.get(k, None)
     if k not in hetero_sampler_out.num_sampled_nodes:
-      hetero_sampler_out.num_sampled_nodes[k] = [0] * (num_hops + 1)
+      hetero_sampler_out.num_sampled_nodes[k] = \
+        torch.tensor([0] * (num_hops + 1), device=data[k].node.device)
     else:
-      hetero_sampler_out.num_sampled_nodes[k] += \
-        [0] * (num_hops + 1 - len(hetero_sampler_out.num_sampled_nodes[k]))
+      hetero_sampler_out.num_sampled_nodes[k] = F.pad(
+        hetero_sampler_out.num_sampled_nodes[k],
+        (0, num_hops + 1 - hetero_sampler_out.num_sampled_nodes[k].size(0))
+      )

   # seed nodes
   for k, v in hetero_sampler_out.batch.items():
diff --git a/graphlearn_torch/python/sampler/base.py b/graphlearn_torch/python/sampler/base.py
index 980cbeb3..507291d2 100644
--- a/graphlearn_torch/python/sampler/base.py
+++ b/graphlearn_torch/python/sampler/base.py
@@ -235,8 +235,8 @@ class SamplerOutput(CastMixin):
   col: torch.Tensor
   edge: Optional[torch.Tensor] = None
   batch: Optional[torch.Tensor] = None
-  num_sampled_nodes: Optional[List[int]] = None
-  num_sampled_edges: Optional[List[int]] = None
+  num_sampled_nodes: Optional[Union[List[int], torch.Tensor]] = None
+  num_sampled_edges: Optional[Union[List[int], torch.Tensor]] = None
   device: Optional[torch.device] = None
   metadata: Optional[Any] = None

@@ -284,8 +284,8 @@ class HeteroSamplerOutput(CastMixin):
   col: Dict[EdgeType, torch.Tensor]
   edge: Optional[Dict[EdgeType, torch.Tensor]] = None
   batch: Optional[Dict[NodeType, torch.Tensor]] = None
-  num_sampled_nodes: Optional[Dict[NodeType, List[int]]] = None
-  num_sampled_edges: Optional[Dict[EdgeType, List[int]]] = None
+  num_sampled_nodes: Optional[Dict[NodeType, Union[List[int], torch.Tensor]]] = None
+  num_sampled_edges: Optional[Dict[EdgeType, Union[List[int], torch.Tensor]]] = None
   edge_types: Optional[List[EdgeType]] = None
   input_type: Optional[Union[NodeType, EdgeType]] = None
   device: Optional[torch.device] = None
diff --git a/graphlearn_torch/python/sampler/neighbor_sampler.py b/graphlearn_torch/python/sampler/neighbor_sampler.py
index 1021cad8..613e28fb 100644
--- a/graphlearn_torch/python/sampler/neighbor_sampler.py
+++ b/graphlearn_torch/python/sampler/neighbor_sampler.py
@@ -258,9 +258,11 @@ def _hetero_sample_from_nodes(
       col=res_cols,
       edge=(res_edges if len(res_edges) else None),
       batch=batch,
-      num_sampled_nodes=num_sampled_nodes,
+      num_sampled_nodes={k : torch.tensor(v, device=self.device)
+                         for k, v in num_sampled_nodes.items()},
       num_sampled_edges={
-        reverse_edge_type(k) : v for k, v in num_sampled_edges.items()},
+        reverse_edge_type(k) : torch.tensor(v, device=self.device)
+        for k, v in num_sampled_edges.items()},
       edge_types=self.edge_types,
       device=self.device
     )
diff --git a/test/python/test_dist_neighbor_loader.py b/test/python/test_dist_neighbor_loader.py
index 1c80dcfe..b8f755fb 100644
--- a/test/python/test_dist_neighbor_loader.py
+++ b/test/python/test_dist_neighbor_loader.py
@@ -49,6 +49,13 @@ def _check_sample_result(data):
       int(rows[i]) == ((int(cols[i]) + 2) % vnum_total)
     )

+  tc.assertEqual(data.num_sampled_nodes[0].item(), 5)
+  tc.assertEqual(data.num_sampled_nodes.size(0), 3)
+  tc.assertNotEqual(data.num_sampled_nodes[1].item(), 0)
+  tc.assertNotEqual(data.num_sampled_nodes[2].item(), 0)
+  tc.assertEqual(data.num_sampled_edges[0].item(), 10)
+  tc.assertEqual(data.num_sampled_edges.size(0), 2)
+  tc.assertNotEqual(data.num_sampled_edges[1].item(), 0)

 def _check_hetero_sample_result(data):
   tc = unittest.TestCase()
@@ -107,6 +114,16 @@ def _check_hetero_sample_result(data):
       int(rev_i2i_rows[i]) == ((int(rev_i2i_cols[i]) + 3) % vnum_total)
     )

+  tc.assertEqual(data.num_sampled_nodes['item'][0].item(), 0)
+  tc.assertNotEqual(data.num_sampled_nodes['item'][1].item(), 0)
+  tc.assertNotEqual(data.num_sampled_nodes['item'][2].item(), 0)
+  tc.assertEqual(data.num_sampled_nodes['user'][0].item(), 5)
+  tc.assertEqual(data.num_sampled_nodes['user'][1].item(), 0)
+  tc.assertEqual(data.num_sampled_nodes['user'][2].item(), 0)
+  tc.assertEqual(data.num_sampled_edges['item', 'rev_u2i', 'user'][0].item(), 10)
+  tc.assertEqual(data.num_sampled_edges['item', 'rev_u2i', 'user'][1].item(), 0)
+  tc.assertEqual(data.num_sampled_edges['item', 'i2i', 'item'][0].item(), 0)
+  tc.assertNotEqual(data.num_sampled_edges['item', 'i2i', 'item'][1].item(), 0)

 def run_test_as_worker(world_size: int, rank: int, master_port: int,
                        sampling_master_port: int,