Revert "[distributed] remove pynccl's redundant stream (vllm-project#…
Browse files Browse the repository at this point in the history
…11744)"

This reverts commit 635b897.
gshtras committed Jan 8, 2025
1 parent 9be84c0 commit 177ad85
Showing 3 changed files with 24 additions and 12 deletions.
5 changes: 3 additions & 2 deletions tests/distributed/test_pynccl.py
@@ -137,8 +137,9 @@ def worker_fn_with_cudagraph():
         # run something in the default stream to initialize torch engine
         a = torch.ones((4, 4), device=f'cuda:{pynccl_comm.rank}')
         torch.cuda.synchronize()
-        with torch.cuda.graph(graph), \
-                pynccl_comm.change_state(enable=True):
+        with torch.cuda.graph(
+                graph, stream=pynccl_comm.stream), pynccl_comm.change_state(
+                    enable=True):
             a_out = pynccl_comm.all_reduce(a)
         torch.cuda.synchronize()
         graph.replay()
28 changes: 19 additions & 9 deletions vllm/distributed/device_communicators/pynccl.py
@@ -51,6 +51,7 @@ def __init__(
         if self.world_size == 1:
             self.available = False
             self.disabled = True
+            self.stream = None
             return
         try:
             self.nccl = NCCLLibrary(library_path)
@@ -59,6 +60,7 @@ def __init__(
             # e.g. in a non-GPU environment
             self.available = False
             self.disabled = True
+            self.stream = None
             return
 
         self.available = True
@@ -96,12 +98,12 @@ def __init__(
         with torch.cuda.device(device):
             self.comm: ncclComm_t = self.nccl.ncclCommInitRank(
                 self.world_size, self.unique_id, self.rank)
+            self.stream = torch.cuda.Stream()
 
-            stream = torch.cuda.current_stream()
             # A small all_reduce for warmup.
             data = torch.zeros(1, device=device)
             self.all_reduce(data)
-            stream.synchronize()
+            self.stream.synchronize()
             del data
 
     def all_reduce(self,
@@ -120,7 +122,7 @@ def all_reduce(self,
         out_tensor = torch.empty_like(in_tensor)
 
         if stream is None:
-            stream = torch.cuda.current_stream()
+            stream = self.stream
         self.nccl.ncclAllReduce(buffer_type(in_tensor.data_ptr()),
                                 buffer_type(out_tensor.data_ptr()),
                                 in_tensor.numel(),
@@ -142,7 +144,7 @@ def all_gather(self,
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {input_tensor.device}")
         if stream is None:
-            stream = torch.cuda.current_stream()
+            stream = self.stream
         self.nccl.ncclAllGather(
             buffer_type(input_tensor.data_ptr()),
             buffer_type(output_tensor.data_ptr()), input_tensor.numel(),
@@ -163,7 +165,7 @@ def reduce_scatter(self,
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {input_tensor.device}")
         if stream is None:
-            stream = torch.cuda.current_stream()
+            stream = self.stream
         self.nccl.ncclReduceScatter(
             buffer_type(input_tensor.data_ptr()),
             buffer_type(output_tensor.data_ptr()), output_tensor.numel(),
@@ -178,7 +180,7 @@ def send(self, tensor: torch.Tensor, dst: int, stream=None):
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {tensor.device}")
         if stream is None:
-            stream = torch.cuda.current_stream()
+            stream = self.stream
         self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(),
                            ncclDataTypeEnum.from_torch(tensor.dtype), dst,
                            self.comm, cudaStream_t(stream.cuda_stream))
@@ -190,7 +192,7 @@ def recv(self, tensor: torch.Tensor, src: int, stream=None):
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {tensor.device}")
         if stream is None:
-            stream = torch.cuda.current_stream()
+            stream = self.stream
         self.nccl.ncclRecv(buffer_type(tensor.data_ptr()), tensor.numel(),
                            ncclDataTypeEnum.from_torch(tensor.dtype), src,
                            self.comm, cudaStream_t(stream.cuda_stream))
@@ -202,7 +204,7 @@ def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {tensor.device}")
         if stream is None:
-            stream = torch.cuda.current_stream()
+            stream = self.stream
         if src == self.rank:
             sendbuff = buffer_type(tensor.data_ptr())
             # NCCL requires the sender also to have a receive buffer
@@ -215,17 +217,25 @@ def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
                                 self.comm, cudaStream_t(stream.cuda_stream))
 
     @contextmanager
-    def change_state(self, enable: Optional[bool] = None):
+    def change_state(self,
+                     enable: Optional[bool] = None,
+                     stream: Optional[torch.cuda.Stream] = None):
         """
         A context manager to change the state of the communicator.
         """
         if enable is None:
             # guess a default value when not specified
             enable = self.available
 
+        if stream is None:
+            stream = self.stream
+
         old_disable = self.disabled
+        old_stream = self.stream
 
+        self.stream = stream
         self.disabled = not enable
         yield
 
         self.disabled = old_disable
+        self.stream = old_stream
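For orientation, a minimal sketch of what the restored change_state signature allows (illustrative only, not part of the commit; pynccl_comm and x stand in for an already-initialized PyNcclCommunicator and a CUDA tensor on its device, neither of which is defined in this diff):

import torch

# Temporarily enable the communicator and point it at a specific stream; on
# exit of the with-block the previous `disabled` flag and `self.stream` are
# restored, as in the change_state context manager above.
s = torch.cuda.Stream()
with pynccl_comm.change_state(enable=True, stream=s):
    # With no explicit stream argument, the collective falls back to
    # self.stream, which is `s` inside this block.
    out = pynccl_comm.all_reduce(x)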
3 changes: 2 additions & 1 deletion vllm/distributed/parallel_state.py
@@ -310,7 +310,8 @@ def graph_capture(
         if not pynccl_comm:
             maybe_pynccl_context = nullcontext()
         else:
-            maybe_pynccl_context = pynccl_comm.change_state()
+            maybe_pynccl_context = pynccl_comm.change_state(
+                stream=torch.cuda.current_stream())
         with maybe_pynccl_context:
             yield graph_capture_context
