From 2a85f9300733c09ec90819bc6df4bff8f103fd67 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 1 May 2024 21:28:21 -0700 Subject: [PATCH] [Core][Distributed] enable multiple tp group (#4512) Co-authored-by: Zhuohan Li --- .buildkite/test-pipeline.yaml | 11 ++++++-- .buildkite/test-template.j2 | 3 ++ tests/distributed/test_pynccl.py | 28 +++++++++++++++++++ .../device_communicators/pynccl.py | 5 +++- 4 files changed, 43 insertions(+), 4 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 641f366d06031..d518fb9ccecfa 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -25,19 +25,24 @@ steps: - label: Distributed Comm Ops Test command: pytest -v -s test_comm_ops.py working_dir: "/vllm-workspace/tests/distributed" - num_gpus: 2 # only support 1 or 2 for now. + num_gpus: 2 - label: Distributed Tests working_dir: "/vllm-workspace/tests/distributed" - num_gpus: 2 # only support 1 or 2 for now. + num_gpus: 2 commands: - - pytest -v -s test_pynccl.py - pytest -v -s test_pynccl_library.py - TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_basic_distributed_correctness.py - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_basic_distributed_correctness.py - TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_chunked_prefill_distributed.py - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_chunked_prefill_distributed.py +- label: Distributed Tests (Multiple Groups) + working_dir: "/vllm-workspace/tests/distributed" + num_gpus: 4 + commands: + - pytest -v -s test_pynccl.py + - label: Engine Test command: pytest -v -s engine tokenization test_sequence.py test_config.py test_logger.py diff --git a/.buildkite/test-template.j2 b/.buildkite/test-template.j2 index 5c9515840bb03..2cb21cacd065b 100644 --- a/.buildkite/test-template.j2 +++ b/.buildkite/test-template.j2 @@ -45,6 +45,9 @@ steps: plugins: - kubernetes: podSpec: + {% if step.num_gpus %} + priorityClassName: gpu-priority-cls-{{ 
step.num_gpus }} + {% endif %} volumes: - name: dshm emptyDir: diff --git a/tests/distributed/test_pynccl.py b/tests/distributed/test_pynccl.py index 6d7d4a5806bd0..e71d839648c83 100644 --- a/tests/distributed/test_pynccl.py +++ b/tests/distributed/test_pynccl.py @@ -58,6 +58,34 @@ def test_pynccl(): distributed_run(worker_fn, 2) +@worker_fn_wrapper +def multiple_tp_worker_fn(): + device = torch.device(f"cuda:{torch.distributed.get_rank()}") + groups = [ + torch.distributed.new_group(ranks=[0, 1], backend="gloo"), + torch.distributed.new_group(ranks=[2, 3], backend="gloo") + ] + group = groups[0] if torch.distributed.get_rank() in [0, 1] else groups[1] + comm = NCCLCommunicator(group=group, device=device) + tensor = torch.ones(16, 1024, 1024, dtype=torch.float32).cuda(device) + # two groups can communicate independently + if torch.distributed.get_rank() in [0, 1]: + comm.all_reduce(tensor) + comm.all_reduce(tensor) + result = tensor.mean().cpu().item() + assert result == 4 + else: + comm.all_reduce(tensor) + result = tensor.mean().cpu().item() + assert result == 2 + + +@pytest.mark.skipif(torch.cuda.device_count() < 4, + reason="Need at least 4 GPUs to run the test.") +def test_pynccl_multiple_tp(): + distributed_run(multiple_tp_worker_fn, 4) + + @worker_fn_wrapper def worker_fn_with_cudagraph(): with torch.no_grad(): diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py index f21fcd262d810..758994352e3de 100644 --- a/vllm/distributed/device_communicators/pynccl.py +++ b/vllm/distributed/device_communicators/pynccl.py @@ -232,6 +232,7 @@ def __init__( assert dist.get_backend(group) != dist.Backend.NCCL, ( "NCCLCommunicator should be attached to a non-NCCL group.") self.group = group + # note: this rank is the rank in the group self.rank = dist.get_rank(group) self.world_size = dist.get_world_size(group) if self.rank == 0: @@ -239,7 +240,9 @@ def __init__( else: self.unique_id = NcclUniqueId() tensor = 
torch.ByteTensor(list(self.unique_id.internal)) - dist.broadcast(tensor, src=0, group=group) + ranks = dist.get_process_group_ranks(group) + # arg `src` in `broadcast` is the global rank + dist.broadcast(tensor, src=ranks[0], group=group) byte_list = tensor.tolist() for i, byte in enumerate(byte_list): self.unique_id.internal[i] = byte