Skip to content

Commit

Permalink
Revert "[distributed] remove pynccl's redundant change_state (vllm-pr…
Browse files Browse the repository at this point in the history
…oject#11749)"

This reverts commit 9e764e7.
  • Loading branch information
gshtras committed Jan 8, 2025
1 parent 88e020d commit 9be84c0
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 28 deletions.
64 changes: 37 additions & 27 deletions tests/distributed/test_pynccl.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ def worker_fn():
device=get_world_group().device)
tensor = torch.ones(16, 1024, 1024,
dtype=torch.float32).cuda(pynccl_comm.rank)
tensor = pynccl_comm.all_reduce(tensor)
with pynccl_comm.change_state(enable=True):
tensor = pynccl_comm.all_reduce(tensor)
torch.cuda.synchronize()
assert torch.all(tensor == pynccl_comm.world_size).cpu().item()

Expand All @@ -80,16 +81,17 @@ def multiple_allreduce_worker_fn():
group = groups[0] if torch.distributed.get_rank() in [0, 1] else groups[1]
pynccl_comm = PyNcclCommunicator(group=group, device=device)
tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
# two groups can communicate independently
if torch.distributed.get_rank() in [0, 1]:
tensor = pynccl_comm.all_reduce(tensor)
tensor = pynccl_comm.all_reduce(tensor)
torch.cuda.synchronize()
assert torch.all(tensor == 4).cpu().item()
else:
tensor = pynccl_comm.all_reduce(tensor)
torch.cuda.synchronize()
assert torch.all(tensor == 2).cpu().item()
with pynccl_comm.change_state(enable=True):
# two groups can communicate independently
if torch.distributed.get_rank() in [0, 1]:
tensor = pynccl_comm.all_reduce(tensor)
tensor = pynccl_comm.all_reduce(tensor)
torch.cuda.synchronize()
assert torch.all(tensor == 4).cpu().item()
else:
tensor = pynccl_comm.all_reduce(tensor)
torch.cuda.synchronize()
assert torch.all(tensor == 2).cpu().item()


@pytest.mark.skipif(torch.cuda.device_count() < 4,
Expand Down Expand Up @@ -135,7 +137,8 @@ def worker_fn_with_cudagraph():
# run something in the default stream to initialize torch engine
a = torch.ones((4, 4), device=f'cuda:{pynccl_comm.rank}')
torch.cuda.synchronize()
with torch.cuda.graph(graph):
with torch.cuda.graph(graph), \
pynccl_comm.change_state(enable=True):
a_out = pynccl_comm.all_reduce(a)
torch.cuda.synchronize()
graph.replay()
Expand Down Expand Up @@ -164,7 +167,8 @@ def all_gather_worker_fn():
for r in range(world_size)
]).to(device)

pynccl_comm.all_gather(result, tensor)
with pynccl_comm.change_state(enable=True):
pynccl_comm.all_gather(result, tensor)
torch.cuda.synchronize()
torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)

Expand Down Expand Up @@ -201,7 +205,8 @@ def reduce_scatter_worker_fn():
expected = sum(tensor[rank * scattered_size:(rank + 1) * scattered_size]
for tensor in all_tensors).to(device)

pynccl_comm.reduce_scatter(result, tensor)
with pynccl_comm.change_state(enable=True):
pynccl_comm.reduce_scatter(result, tensor)
torch.cuda.synchronize()
torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)

Expand All @@ -228,13 +233,15 @@ def send_recv_worker_fn():
else:
tensor = torch.empty(16, 1024, 1024,
dtype=torch.float32).cuda(pynccl_comm.rank)

if pynccl_comm.rank == 0:
pynccl_comm.send(tensor,
dst=(pynccl_comm.rank + 1) % pynccl_comm.world_size)
else:
pynccl_comm.recv(tensor,
src=(pynccl_comm.rank - 1) % pynccl_comm.world_size)
with pynccl_comm.change_state(enable=True):
if pynccl_comm.rank == 0:
pynccl_comm.send(tensor,
dst=(pynccl_comm.rank + 1) %
pynccl_comm.world_size)
else:
pynccl_comm.recv(tensor,
src=(pynccl_comm.rank - 1) %
pynccl_comm.world_size)
torch.cuda.synchronize()
assert torch.all(tensor == 1).cpu().item()

Expand Down Expand Up @@ -265,12 +272,15 @@ def multiple_send_recv_worker_fn():
1024,
dtype=torch.float32,
device=device)
if torch.distributed.get_rank() in [0, 1]:
pynccl_comm.send(tensor,
dst=(pynccl_comm.rank + 1) % pynccl_comm.world_size)
else:
pynccl_comm.recv(tensor,
src=(pynccl_comm.rank - 1) % pynccl_comm.world_size)
with pynccl_comm.change_state(enable=True):
if torch.distributed.get_rank() in [0, 1]:
pynccl_comm.send(tensor,
dst=(pynccl_comm.rank + 1) %
pynccl_comm.world_size)
else:
pynccl_comm.recv(tensor,
src=(pynccl_comm.rank - 1) %
pynccl_comm.world_size)
torch.cuda.synchronize()
if torch.distributed.get_rank() in [0, 2]:
assert torch.all(tensor == 1).cpu().item()
Expand Down
17 changes: 17 additions & 0 deletions vllm/distributed/device_communicators/pynccl.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from contextlib import contextmanager
from typing import Optional, Union

# ===================== import region =====================
Expand Down Expand Up @@ -212,3 +213,19 @@ def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
self.nccl.ncclBroadcast(sendbuff, recvbuff, tensor.numel(),
ncclDataTypeEnum.from_torch(tensor.dtype), src,
self.comm, cudaStream_t(stream.cuda_stream))

@contextmanager
def change_state(self, enable: Optional[bool] = None):
"""
A context manager to change the state of the communicator.
"""
if enable is None:
# guess a default value when not specified
enable = self.available

old_disable = self.disabled

self.disabled = not enable
yield

self.disabled = old_disable
9 changes: 8 additions & 1 deletion vllm/distributed/parallel_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,14 @@ def graph_capture(
stream.wait_stream(curr_stream)

with torch.cuda.stream(stream), maybe_ca_context:
yield graph_capture_context
pynccl_comm = self.pynccl_comm
maybe_pynccl_context: Any
if not pynccl_comm:
maybe_pynccl_context = nullcontext()
else:
maybe_pynccl_context = pynccl_comm.change_state()
with maybe_pynccl_context:
yield graph_capture_context

def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
"""
Expand Down

0 comments on commit 9be84c0

Please sign in to comment.