Commit bebf64f

[Feat] range partition book (#146)

1 parent 36ce42b commit bebf64f

File tree: 7 files changed, +219 -58 lines

graphlearn_torch/python/partition/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -15,4 +15,5 @@
 
 from .base import *
 from .frequency_partitioner import FrequencyPartitioner
+from .partition_book import *
 from .random_partitioner import RandomPartitioner

graphlearn_torch/python/partition/base.py

Lines changed: 0 additions & 6 deletions

@@ -36,12 +36,6 @@ def __getitem__(self, indices):
   def offset(self):
     return 0
 
-class GLTPartitionBook(PartitionBook, torch.Tensor):
-  r""" A partition book of graph nodes or edges.
-  """
-  def __getitem__(self, indices) -> torch.Tensor:
-    return torch.Tensor.__getitem__(self, indices)
-
 HeteroNodePartitionDict = Dict[NodeType, PartitionBook]
 HeteroEdgePartitionDict = Dict[EdgeType, PartitionBook]

Lines changed: 72 additions & 0 deletions

@@ -0,0 +1,72 @@
+import torch
+from typing import List, Tuple
+from .base import PartitionBook
+
+
+class RangePartitionBook(PartitionBook):
+  r"""A class for managing range-based partitions of consecutive IDs.
+  Suitable when IDs within each partition are consecutive.
+  Args:
+    partition_ranges (List[Tuple[int, int]]): A list of tuples representing
+      the start and end (exclusive) of each partition range.
+    partition_idx (int): The index of the current partition.
+  Example:
+    >>> partition_ranges = [(0, 10), (10, 20), (20, 30)]
+    >>> range_pb = RangePartitionBook(partition_ranges, partition_idx=1)
+    >>> indices = torch.tensor([0, 5, 10, 15, 20, 25])
+    >>> partition_ids = range_pb[indices]
+    >>> print(partition_ids)
+    tensor([0, 0, 1, 1, 2, 2])
+  """
+
+  def __init__(self, partition_ranges: List[Tuple[int, int]], partition_idx: int):
+    if not all(r[0] < r[1] for r in partition_ranges):
+      raise ValueError("All partition ranges must have start < end")
+    if not all(r1[1] == r2[0] for r1, r2 in zip(partition_ranges[:-1], partition_ranges[1:])):
+      raise ValueError("Partition ranges must be continuous")
+
+    self.partition_bounds = torch.tensor(
+      [end for _, end in partition_ranges], dtype=torch.long)
+    self.partition_idx = partition_idx
+    self._id2index = OffsetId2Index(partition_ranges[partition_idx][0])
+
+  def __getitem__(self, indices: torch.Tensor) -> torch.Tensor:
+    return torch.searchsorted(self.partition_bounds, indices, right=True)
+
+  @property
+  def device(self):
+    return self.partition_bounds.device
+
+  @property
+  def id2index(self):
+    return self._id2index
+
+  def id_filter(self, node_pb: PartitionBook, partition_idx: int):
+    start = self.partition_bounds[partition_idx-1] if partition_idx > 0 else 0
+    end = self.partition_bounds[partition_idx]
+    return torch.arange(start, end)
+
+
+class OffsetId2Index:
+  r"""
+  Convert global IDs to local indices by subtracting a specified offset.
+  """
+
+  def __init__(self, offset: int):
+    self.offset = offset
+
+  def __getitem__(self, ids: torch.Tensor) -> torch.Tensor:
+    local_indices = ids - self.offset
+    return local_indices
+
+  def to(self, device):
+    # device is always same as the input ids
+    return self
+
+
+class GLTPartitionBook(PartitionBook, torch.Tensor):
+  r""" A partition book of graph nodes or edges.
+  """
+
+  def __getitem__(self, indices) -> torch.Tensor:
+    return torch.Tensor.__getitem__(self, indices)
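A quick usage sketch of the new range partition book, assuming the package is importable as graphlearn_torch.partition (the tests below reach it through glt.partition); the two ranges mirror the 2-partition test setup and the printed values follow from the searchsorted and offset logic above.

import torch
from graphlearn_torch.partition import RangePartitionBook  # assumed import path

# Two partitions of 20 consecutive node IDs each: [0, 20) and [20, 40).
node_pb = RangePartitionBook([(0, 20), (20, 40)], partition_idx=1)

ids = torch.tensor([0, 19, 20, 39])
print(node_pb[ids])              # tensor([0, 0, 1, 1]): partition lookup via searchsorted
local = torch.tensor([20, 25, 39])
print(node_pb.id2index[local])   # tensor([ 0,  5, 19]): local index = global ID - 20

The id2index property is what the test dataset below passes to glt.data.Feature in place of glt.utils.id2idx when range partitioning is enabled.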

test/python/dist_test_utils.py

Lines changed: 74 additions & 30 deletions

@@ -20,9 +20,11 @@
 
 # options for dataset generation
 vnum_per_partition = 20
-vnum_total = vnum_per_partition * 2
+num_partition = 2
+vnum_total = vnum_per_partition * num_partition  # 40
 degree = 2
-enum_total = vnum_total * degree
+enum_per_partition = vnum_per_partition * degree  # 40
+enum_total = enum_per_partition * num_partition  # 80
 
 # for hetero dataset
 user_ntype = 'user'
@@ -36,64 +38,106 @@
 device_num = 2
 
 
-def _prepare_dataset(rank: int, weighted: bool = False):
-  # partition
-  node_pb = torch.tensor(
-    [v % 2 for v in range(0, vnum_total)],
-    dtype=torch.long
-  )
-  edge_pb = torch.tensor(
-    [((e // degree) % 2) for e in range(0, enum_total)],
-    dtype=torch.long
-  )
+def _prepare_dataset(rank: int,
+                     weighted: bool = False,
+                     is_range_partition: bool = False):
+  """
+  Prepare a synthetic graph dataset with 40 nodes and 80 edges for unit tests.
+
+  Graph topology:
+    - rows: [0, 0, 1, 1, 2, 2, ... 37, 37, 38, 38, 39, 39]
+    - cols: [1, 2, 2, 3, 3, 4, ... 38, 39, 39, 0, 0, 1]
+    - eids: [0, 1, 2, 3, 4, 5, ... 74, 75, 76, 77, 78, 79]
+
+  Node features:
+    [[0., 0., ..., 0., 0.],
+     [1., 1., ..., 1., 1.],
+     ...
+     [39., 39., ..., 39., 39.]]
+
+  Edge features:
+    [[0., 0., ..., 0., 0.],
+     [1., 1., ..., 1., 1.],
+     ...
+     [79., 79., ..., 79., 79.]]
+
+  Two partition strategies are available:
+  1. Range partition:
+     - Nodes with IDs [0, 19] and edges with IDs [0, 39] are on partition 0
+     - Nodes with IDs [20, 39] and edges with IDs [40, 79] are on partition 1
+  2. Hash partition:
+     - Even-numbered nodes and edges are on partition 0
+     - Odd-numbered nodes and edges are on partition 1
+
+  The graph topology and features are identical under both partition strategies.
+  """
+  if is_range_partition:
+    node_ranges = [(0, vnum_per_partition), (vnum_per_partition, vnum_total)]
+    edge_ranges = [(0, enum_total // 2), (enum_total // 2, enum_total)]
+    node_pb = glt.partition.RangePartitionBook(
+      node_ranges, rank)
+    edge_pb = glt.partition.RangePartitionBook(
+      edge_ranges, rank)
+    start, end, step = rank * vnum_per_partition, (rank + 1) * vnum_per_partition, 1
+  else:
+    node_pb = torch.tensor(
+      [v % 2 for v in range(0, vnum_total)],
+      dtype=torch.long
+    )
+    edge_pb = torch.tensor(
+      [((e // degree) % 2) for e in range(0, enum_total)],
+      dtype=torch.long
+    )
+    start, end, step = rank, vnum_total, 2
+
 
   # graph
   nodes, rows, cols, eids = [], [], [], []
-  for v in range(rank, vnum_total, 2):
+  for v in range(start, end, step):
     nodes.append(v)
     rows.extend([v for _ in range(degree)])
     cols.extend([((v + i + 1) % vnum_total) for i in range(degree)])
     eids.extend([(v * degree + i) for i in range(degree)])
+
   edge_index = torch.tensor([rows, cols], dtype=torch.int64)
   edge_ids = torch.tensor(eids, dtype=torch.int64)
   edge_weights = (edge_ids % 2).to(torch.float)
   csr_topo = glt.data.Topology(edge_index=edge_index, edge_ids=edge_ids)
+  graph = glt.data.Graph(csr_topo, 'ZERO_COPY', device=0)
+
   weighted_csr_topo = glt.data.Topology(
     edge_index=edge_index, edge_ids=edge_ids, edge_weights=edge_weights)
-  graph = glt.data.Graph(csr_topo, 'ZERO_COPY', device=0)
   weighted_graph = glt.data.Graph(weighted_csr_topo, 'CPU')
 
   # feature
   device_group_list = [glt.data.DeviceGroup(0, [0]),
                        glt.data.DeviceGroup(1, [1])]
   split_ratio = 0.2
 
-  nfeat = rank + torch.zeros(len(nodes), 512, dtype=torch.float32)
-  nfeat_id2idx = glt.utils.id2idx(nodes)
+  nfeat = torch.tensor(nodes, dtype=torch.float32).unsqueeze(1).repeat(1, 512)
+  nfeat_id2idx = node_pb.id2index if is_range_partition else glt.utils.id2idx(nodes)
   node_feature = glt.data.Feature(nfeat, nfeat_id2idx, split_ratio,
                                   device_group_list, device=0)
 
-  efeat = rank + torch.ones(len(eids), 10, dtype=torch.float32)
-  efeat_id2idx = glt.utils.id2idx(eids)
+  efeat = torch.tensor(eids, dtype=torch.float32).unsqueeze(1).repeat(1, 10)
+  efeat_id2idx = edge_pb.id2index if is_range_partition else glt.utils.id2idx(eids)
   edge_feature = glt.data.Feature(efeat, efeat_id2idx, split_ratio,
                                   device_group_list, device=0)
 
   # whole node label
   node_label = torch.arange(vnum_total)
 
   # dist dataset
-  if weighted:
-    return glt.distributed.DistDataset(
-      2, rank,
-      weighted_graph, node_feature, edge_feature, node_label,
-      node_pb, edge_pb
-    )
-  else:
-    return glt.distributed.DistDataset(
-      2, rank,
-      graph, node_feature, edge_feature, node_label,
-      node_pb, edge_pb
-    )
+  ds = glt.distributed.DistDataset(
+    2, rank,
+    weighted_graph if weighted else graph,
+    node_feature, edge_feature, node_label,
+    node_pb, edge_pb
+  )
+
+  if is_range_partition:
+    ds.id_filter = node_pb.id_filter
+  return ds
 
 
 def _prepare_hetero_dataset(
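As a side note on the two strategies described in the docstring above, each rank owns the same number of nodes either way, just different IDs; a standalone sketch of the (start, end, step) selection used by _prepare_dataset (illustration only, not part of the test file):

vnum_per_partition, num_partition = 20, 2
vnum_total = vnum_per_partition * num_partition  # 40

rank = 0
range_nodes = list(range(rank * vnum_per_partition, (rank + 1) * vnum_per_partition, 1))  # [0, 1, ..., 19]
hash_nodes = list(range(rank, vnum_total, 2))                                             # [0, 2, ..., 38]
assert len(range_nodes) == len(hash_nodes) == vnum_per_partition  # 20 nodes per rank under either strategy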

test/python/test_dist_link_loader.py

Lines changed: 25 additions & 7 deletions

@@ -21,6 +21,7 @@
 
 from dist_test_utils import *
 from dist_test_utils import _prepare_dataset, _prepare_hetero_dataset
+from parameterized import parameterized
 
 def _check_sample_result(data, edge_dir='out'):
   tc = unittest.TestCase()

@@ -338,6 +339,8 @@ class DistLinkNeighborLoaderTestCase(unittest.TestCase):
   def setUp(self):
     self.dataset0 = _prepare_dataset(rank=0)
     self.dataset1 = _prepare_dataset(rank=1)
+    self.range_partition_dataset0 = _prepare_dataset(rank=0, is_range_partition=True)
+    self.range_partition_dataset1 = _prepare_dataset(rank=1, is_range_partition=True)
     self.input_edges0 = torch.stack(
       (torch.arange(vnum_per_partition), torch.arange(vnum_per_partition)+1)
     ).to(dtype=torch.long)

@@ -357,37 +360,52 @@ def setUp(self):
     self.master_port = glt.utils.get_free_port()
     self.sampling_master_port = glt.utils.get_free_port()
 
-  def test_homo_out_sample_collocated(self):
+  def _get_homo_datasets(self, is_range_partition):
+    return (self.range_partition_dataset0, self.range_partition_dataset1) if is_range_partition else (self.dataset0, self.dataset1)
+
+  @parameterized.expand([
+    (True),
+    (False),
+  ])
+  def test_homo_out_sample_collocated(self, is_range_partition):
     print("\n--- DistLinkNeighborLoader Test (homogeneous, collocated) ---")
+    dataset0, dataset1 = self._get_homo_datasets(is_range_partition)
+
     mp_context = torch.multiprocessing.get_context('spawn')
     w0 = mp_context.Process(
       target=run_test_as_worker,
       args=(2, 0, self.master_port, self.sampling_master_port,
-            self.dataset0, self.bin_neg_sampling, self.input_edges0, _check_sample_result, True)
+            dataset0, self.bin_neg_sampling, self.input_edges0, _check_sample_result, True)
     )
     w1 = mp_context.Process(
       target=run_test_as_worker,
       args=(2, 1, self.master_port, self.sampling_master_port,
-            self.dataset1, self.bin_neg_sampling, self.input_edges1, _check_sample_result, True)
+            dataset1, self.bin_neg_sampling, self.input_edges1, _check_sample_result, True)
     )
     w0.start()
     w1.start()
     w0.join()
     w1.join()
-
-  def test_homo_out_sample_mp(self):
+
+  @parameterized.expand([
+    (True),
+    (False),
+  ])
+  def test_homo_out_sample_mp(self, is_range_partition):
     print("\n--- DistLinkNeighborLoader Test (homogeneous, multiprocessing) ---")
+    dataset0, dataset1 = self._get_homo_datasets(is_range_partition)
+
     mp_context = torch.multiprocessing.get_context('spawn')
     w0 = mp_context.Process(
       target=run_test_as_worker,
      args=(2, 0, self.master_port, self.sampling_master_port,
-            self.dataset0, self.tri_neg_sampling, self.input_edges0,
+            dataset0, self.tri_neg_sampling, self.input_edges0,
            _check_sample_result, False)
     )
     w1 = mp_context.Process(
       target=run_test_as_worker,
       args=(2, 1, self.master_port, self.sampling_master_port,
-            self.dataset1, self.tri_neg_sampling, self.input_edges1,
+            dataset1, self.tri_neg_sampling, self.input_edges1,
             _check_sample_result, False)
     )
     w0.start()
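The loader tests now exercise each scenario under both partition strategies via parameterized.expand, which generates one test method per listed parameter; a minimal standalone illustration of the pattern (hypothetical ExampleTestCase, not part of the commit):

import unittest
from parameterized import parameterized

class ExampleTestCase(unittest.TestCase):
  @parameterized.expand([
    (True,),
    (False,),
  ])
  def test_flag(self, is_range_partition):
    # expand() generates one test method per parameter tuple above.
    self.assertIn(is_range_partition, (True, False))

if __name__ == '__main__':
  unittest.main()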
