diff --git a/tests/distributed/test_pynccl.py b/tests/distributed/test_pynccl.py
index a8571a1157892..a77b48d5e49f3 100644
--- a/tests/distributed/test_pynccl.py
+++ b/tests/distributed/test_pynccl.py
@@ -59,7 +59,8 @@ def worker_fn():
                                      device=get_world_group().device)
     tensor = torch.ones(16, 1024, 1024,
                         dtype=torch.float32).cuda(pynccl_comm.rank)
-    tensor = pynccl_comm.all_reduce(tensor)
+    with pynccl_comm.change_state(enable=True):
+        tensor = pynccl_comm.all_reduce(tensor)
     torch.cuda.synchronize()
     assert torch.all(tensor == pynccl_comm.world_size).cpu().item()
 
@@ -80,16 +81,17 @@ def multiple_allreduce_worker_fn():
     group = groups[0] if torch.distributed.get_rank() in [0, 1] else groups[1]
     pynccl_comm = PyNcclCommunicator(group=group, device=device)
     tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
-    # two groups can communicate independently
-    if torch.distributed.get_rank() in [0, 1]:
-        tensor = pynccl_comm.all_reduce(tensor)
-        tensor = pynccl_comm.all_reduce(tensor)
-        torch.cuda.synchronize()
-        assert torch.all(tensor == 4).cpu().item()
-    else:
-        tensor = pynccl_comm.all_reduce(tensor)
-        torch.cuda.synchronize()
-        assert torch.all(tensor == 2).cpu().item()
+    with pynccl_comm.change_state(enable=True):
+        # two groups can communicate independently
+        if torch.distributed.get_rank() in [0, 1]:
+            tensor = pynccl_comm.all_reduce(tensor)
+            tensor = pynccl_comm.all_reduce(tensor)
+            torch.cuda.synchronize()
+            assert torch.all(tensor == 4).cpu().item()
+        else:
+            tensor = pynccl_comm.all_reduce(tensor)
+            torch.cuda.synchronize()
+            assert torch.all(tensor == 2).cpu().item()
 
 
 @pytest.mark.skipif(torch.cuda.device_count() < 4,
@@ -135,7 +137,8 @@ def worker_fn_with_cudagraph():
         # run something in the default stream to initialize torch engine
         a = torch.ones((4, 4), device=f'cuda:{pynccl_comm.rank}')
         torch.cuda.synchronize()
-        with torch.cuda.graph(graph):
+        with torch.cuda.graph(graph), \
+                pynccl_comm.change_state(enable=True):
             a_out = pynccl_comm.all_reduce(a)
         torch.cuda.synchronize()
         graph.replay()
@@ -164,7 +167,8 @@ def all_gather_worker_fn():
         for r in range(world_size)
     ]).to(device)
 
-    pynccl_comm.all_gather(result, tensor)
+    with pynccl_comm.change_state(enable=True):
+        pynccl_comm.all_gather(result, tensor)
     torch.cuda.synchronize()
     torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)
 
@@ -201,7 +205,8 @@ def reduce_scatter_worker_fn():
     expected = sum(tensor[rank * scattered_size:(rank + 1) * scattered_size]
                    for tensor in all_tensors).to(device)
 
-    pynccl_comm.reduce_scatter(result, tensor)
+    with pynccl_comm.change_state(enable=True):
+        pynccl_comm.reduce_scatter(result, tensor)
     torch.cuda.synchronize()
     torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)
 
@@ -228,13 +233,15 @@ def send_recv_worker_fn():
     else:
         tensor = torch.empty(16, 1024, 1024,
                              dtype=torch.float32).cuda(pynccl_comm.rank)
-
-    if pynccl_comm.rank == 0:
-        pynccl_comm.send(tensor,
-                         dst=(pynccl_comm.rank + 1) % pynccl_comm.world_size)
-    else:
-        pynccl_comm.recv(tensor,
-                         src=(pynccl_comm.rank - 1) % pynccl_comm.world_size)
+    with pynccl_comm.change_state(enable=True):
+        if pynccl_comm.rank == 0:
+            pynccl_comm.send(tensor,
+                             dst=(pynccl_comm.rank + 1) %
+                             pynccl_comm.world_size)
+        else:
+            pynccl_comm.recv(tensor,
+                             src=(pynccl_comm.rank - 1) %
+                             pynccl_comm.world_size)
     torch.cuda.synchronize()
     assert torch.all(tensor == 1).cpu().item()
 
@@ -265,12 +272,15 @@ def multiple_send_recv_worker_fn():
                              1024,
                              dtype=torch.float32,
                              device=device)
-    if torch.distributed.get_rank() in [0, 1]:
-        pynccl_comm.send(tensor,
-                         dst=(pynccl_comm.rank + 1) % pynccl_comm.world_size)
-    else:
-        pynccl_comm.recv(tensor,
-                         src=(pynccl_comm.rank - 1) % pynccl_comm.world_size)
+    with pynccl_comm.change_state(enable=True):
+        if torch.distributed.get_rank() in [0, 1]:
+            pynccl_comm.send(tensor,
+                             dst=(pynccl_comm.rank + 1) %
+                             pynccl_comm.world_size)
+        else:
+            pynccl_comm.recv(tensor,
+                             src=(pynccl_comm.rank - 1) %
+                             pynccl_comm.world_size)
     torch.cuda.synchronize()
     if torch.distributed.get_rank() in [0, 2]:
         assert torch.all(tensor == 1).cpu().item()
diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py
index fda4d007ceb5b..93d96fd8f5686 100644
--- a/vllm/distributed/device_communicators/pynccl.py
+++ b/vllm/distributed/device_communicators/pynccl.py
@@ -1,3 +1,4 @@
+from contextlib import contextmanager
 from typing import Optional, Union
 
 # ===================== import region =====================
@@ -212,3 +213,19 @@ def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
         self.nccl.ncclBroadcast(sendbuff, recvbuff, tensor.numel(),
                                 ncclDataTypeEnum.from_torch(tensor.dtype), src,
                                 self.comm, cudaStream_t(stream.cuda_stream))
+
+    @contextmanager
+    def change_state(self, enable: Optional[bool] = None):
+        """
+        A context manager to change the state of the communicator.
+        """
+        if enable is None:
+            # guess a default value when not specified
+            enable = self.available
+
+        old_disable = self.disabled
+
+        self.disabled = not enable
+        yield
+
+        self.disabled = old_disable
diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py
index a837c1dc5953b..dccd3addbcb35 100644
--- a/vllm/distributed/parallel_state.py
+++ b/vllm/distributed/parallel_state.py
@@ -305,7 +305,14 @@ def graph_capture(
             stream.wait_stream(curr_stream)
 
         with torch.cuda.stream(stream), maybe_ca_context:
-            yield graph_capture_context
+            pynccl_comm = self.pynccl_comm
+            maybe_pynccl_context: Any
+            if not pynccl_comm:
+                maybe_pynccl_context = nullcontext()
+            else:
+                maybe_pynccl_context = pynccl_comm.change_state()
+            with maybe_pynccl_context:
+                yield graph_capture_context
 
     def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
         """
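Note on usage (not part of the patch): with this change, PyNcclCommunicator collectives only run while the communicator's `disabled` flag is cleared, so callers opt in through the new `change_state` context manager, and `graph_capture` enables it around the capture region. Below is a minimal sketch of the opt-in pattern, mirroring the updated tests; it assumes the distributed environment and world group have already been initialized the way the test worker functions do, and that at least one GPU is available.

import torch
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.parallel_state import get_world_group

# Build a communicator for the current world group, as the tests do.
pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
                                 device=get_world_group().device)
tensor = torch.ones(16, 1024, 1024,
                    dtype=torch.float32).cuda(pynccl_comm.rank)

# change_state(enable=True) clears `disabled` for the duration of the block
# and restores the previous value afterwards; called with no argument it
# falls back to `self.available`, which is how graph_capture() uses it.
with pynccl_comm.change_state(enable=True):
    tensor = pynccl_comm.all_reduce(tensor)
torch.cuda.synchronize()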