Commit

minor
LiSu committed Dec 6, 2023
1 parent 4f82656 commit 0e8b8eb
Showing 5 changed files with 29 additions and 24 deletions.
28 changes: 17 additions & 11 deletions examples/igbh/README.md
@@ -34,21 +34,21 @@ into an undirected graph.
python train_rgnn.py --model='rgat' --dataset_size='tiny' --num_classes=19
```
The script uses a single GPU; please add `--cpu_mode` if you want to use the CPU only.
To save the memory costs while training large datasets, add `--use_fp16` to store
To save memory cost when training on large datasets, add `--use_fp16` to load
feature data in FP16 format. The `--pin_feature` option determines whether the feature
data is pinned in host memory, which enables zero-copy feature access from the GPU
but incurs extra memory cost.
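
For example, a sketch of a single-GPU run that loads the feature in FP16 and pins it in host memory (simply combining the flags described above):
```
python train_rgnn.py --model='rgat' --dataset_size='tiny' --num_classes=19 --use_fp16 --pin_feature
```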

To train the model using multiple GPUs using FP16 format wihtout pinning the feature:
To train the model using multiple GPUs and FP16 format without pinning the feature:
```
CUDA_VISIBLE_DEVICES=0,1 python train_rgnn_multi_gpu.py --model='rgat' --dataset_size='tiny' --num_classes=19 --use_fp16
```

Note that the original graph is in COO fornat, the above scripts will transform
the graph from COO to CSC or CSR according to the edge direction of sampling.
If `--use_fp16` is enabled, the feature will be converted from `fp32`into `fp16`.
This process could be time consuming. We provide a script to convert and persist
the graph layout (from `COO` to `CSC` or `CSR`) and the data type of feature:
Note that the original graph is in `COO` format; the above scripts will transform
the graph from `COO` to `CSC` or `CSR` according to the edge direction of sampling.
This process can be time-consuming. We provide a script to convert the graph layout
from `COO` to `CSC/CSR` and persist the feature in FP16 format:

```
python compress_graph.py --dataset_size='tiny' --layout='CSC' --use_fp16
```
@@ -77,13 +77,17 @@ python partition.py --dataset_size='tiny' --num_partitions=2 --num_classes=19
GLT also supports two-stage partitioning, which separates topology partitioning
from feature partitioning. After the topology partitioning is executed,
the feature partitioning can be conducted on each training node in parallel
to speedup the partitioning.
to speed up the partitioning process.

The topology partitioning is conducted by setting `--with_feature=0`:
```
python partition.py --dataset_size='tiny' --num_partitions=2 --num_classes=19 --with_feature=0
```

By default, the partitioned graph is stored in `COO` format; `CSC` and `CSR` are also
supported by setting `--layout` for `partition.py`.
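
For example, a sketch of partitioning directly into the `CSC` layout (reusing the flags shown above):
```
python partition.py --dataset_size='tiny' --num_partitions=2 --num_classes=19 --layout='CSC'
```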


The feature partitioning is conducted on each training node:
```
# node 0 which holds partition 0:
@@ -92,6 +96,8 @@ python build_partition_feature.py --dataset_size='tiny' --in_memory=0 --partition_idx=0
# node 1 which holds partition 1:
python build_partition_feature.py --dataset_size='tiny' --in_memory=0 --partition_idx=1
```
Building the partition feature with `--use_fp16` converts the feature data type
from FP32 to FP16.
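
For example, a sketch of building the feature of partition 0 in FP16 (the same flags as above, plus `--use_fp16`):
```
# node 0 which holds partition 0:
python build_partition_feature.py --dataset_size='tiny' --in_memory=0 --partition_idx=0 --use_fp16
```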

### 3.2 Example of distributed training
This example uses 2 nodes, each with 2 GPUs.
@@ -108,12 +114,12 @@ To separate the GPUs used by the sampling and training processes, please add `--split_training_sampling`

```
# node 0:
CUDA_VISIBLE_DEVICES=0,1 python dist_train_rgnn.py --num_nodes=2 --node_rank=0 --num_training_procs=1 --master_addr=localhost --model='rgat' --dataset_size='tiny' --num_classes=19
CUDA_VISIBLE_DEVICES=0,1 python dist_train_rgnn.py --num_nodes=2 --node_rank=0 --num_training_procs=1 --master_addr=localhost --model='rgat' --dataset_size='tiny' --num_classes=19 --split_training_sampling
# node 1:
CUDA_VISIBLE_DEVICES=2,3 python dist_train_rgnn.py --num_nodes=2 --node_rank=1 --num_training_procs=1 --master_addr=localhost --model='rgat' --dataset_size='tiny' --num_classes=19
CUDA_VISIBLE_DEVICES=2,3 python dist_train_rgnn.py --num_nodes=2 --node_rank=1 --num_training_procs=1 --master_addr=localhost --model='rgat' --dataset_size='tiny' --num_classes=19 --split_training_sampling
```
The script uses one GPU for training and another GPU for sampling on each node.
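
For reference, a minimal sketch of the device assignment this implies for each local training process, mirroring the updated selection logic in `dist_train_rgnn.py` shown later in this commit (`assign_devices` is an illustrative helper, not part of the script, and assumes GPUs are available):
```
import torch

def assign_devices(local_proc_rank: int, split_training_sampling: bool):
  # With --split_training_sampling, training process i uses GPU 2*i for
  # training and GPU 2*i+1 for sampling; otherwise it trains and samples
  # on the same GPU.
  num_gpus = torch.cuda.device_count()
  if split_training_sampling:
    train_device = torch.device((local_proc_rank * 2) % num_gpus)
    sampling_device = torch.device((local_proc_rank * 2 + 1) % num_gpus)
  else:
    train_device = torch.device(local_proc_rank % num_gpus)
    sampling_device = train_device
  return train_device, sampling_device
```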

Note:
- The `num_partitions` and `num_nodes` must be the same.
2 changes: 1 addition & 1 deletion examples/igbh/compress_graph.py
@@ -99,7 +99,7 @@ def process(self):

for etype in self.etypes:
graph = glt_dataset.get_graph(etype)
indptr, indices = graph.export_topology()
indptr, indices, _ = graph.export_topology()
path = os.path.join(self.dir, self.dataset_size, 'processed', self.layout, compress_edge_dict[etype])
if not os.path.exists(path):
os.makedirs(path)
13 changes: 7 additions & 6 deletions examples/igbh/dist_train_rgnn.py
@@ -74,13 +74,14 @@ def run_training_proc(local_proc_rank, num_nodes, node_rank, num_training_procs,

current_ctx = glt.distributed.get_context()
if with_gpu:
current_device = torch.device((local_proc_rank * 2) % torch.cuda.device_count())
if split_training_sampling:
current_device = torch.device((local_proc_rank * 2) % torch.cuda.device_count())
sampling_device = torch.device((local_proc_rank * 2 + 1) % torch.cuda.device_count())
else:
current_device = torch.device(local_proc_rank % torch.cuda.device_count())
sampling_device = current_device
else:
current_device = torch.device('cpu')

if split_training_sampling:
sampling_device = torch.device((local_proc_rank * 2 + 1) % torch.cuda.device_count())
else:
sampling_device = current_device

# Initialize training process group of PyTorch.
@@ -278,7 +279,7 @@ def run_training_proc(local_proc_rank, num_nodes, node_rank, num_training_procs,
parser.add_argument('--layout', type=str, default='COO',
help="Layout of input graph: CSC, CSR, COO. Default is COO.")
parser.add_argument("--rpc_timeout", type=int, default=180,
help="rpc timeout in seconds")
help="rpc timeout in seconds")
parser.add_argument("--split_training_sampling", action="store_true",
help="Use seperate GPUs for training and sampling processes.")
parser.add_argument("--with_trim", action="store_true",
8 changes: 3 additions & 5 deletions examples/igbh/partition.py
@@ -22,7 +22,6 @@
from dataset import IGBHeteroDataset
from typing import Literal


def partition_dataset(src_path: str,
dst_path: str,
num_partitions: int,
@@ -33,7 +32,7 @@ def partition_dataset(src_path: str,
use_label_2K: bool=False,
with_feature: bool=True,
use_fp16: bool=False,
layout: Literal['CSC', 'CSR', 'COO'] = 'CSC'):
layout: Literal['CSC', 'CSR', 'COO'] = 'COO'):
print(f'-- Loading igbh_{dataset_size} ...')
data = IGBHeteroDataset(src_path, dataset_size, in_memory, use_label_2K, use_fp16=use_fp16)
node_num = {k : v.shape[0] for k, v in data.feat_dict.items()}
@@ -117,7 +116,7 @@ def partition_dataset(src_path: str,

for etype in graph_dict:
graph = dataset.get_graph(etype)
indptr, indices = graph.export_topology()
indptr, indices, _ = graph.export_topology()
path = osp.join(base_path, compress_edge_dict[etype])
if layout == 'CSR':
torch.save(indptr, osp.join(path, 'rows.pt'))
@@ -126,7 +125,6 @@ def partition_dataset(src_path: str,
torch.save(indptr, osp.join(path, 'cols.pt'))
torch.save(indices, osp.join(path, 'rows.pt'))


if __name__ == '__main__':
root = osp.join(osp.dirname(osp.dirname(osp.dirname(osp.realpath(__file__)))), 'data', 'igbh')
glt.utils.ensure_dir(root)
@@ -152,7 +150,7 @@ def partition_dataset(src_path: str,
choices=[0, 1], help='0:do not partition feature, 1:partition feature')
parser.add_argument('--use_fp16', action="store_true",
help="save partitioned node/edge feature into fp16 format")
parser.add_argument("--layout", type=str, default='CSC',
parser.add_argument("--layout", type=str, default='COO',
help="layout of the partitioned graph: CSC, CSR, COO")

args = parser.parse_args()
2 changes: 1 addition & 1 deletion graphlearn_torch/python/data/graph.py
@@ -252,7 +252,7 @@ def lazy_init(self):
f"invalid mode {self.mode}")

def export_topology(self):
return self.topo.indptr, self.topo.indices
return self.topo.indptr, self.topo.indices, self.topo.edge_ids

def share_ipc(self):
r""" Create ipc handle for multiprocessing.
