From 177ad8538e18e5a59a219d844b7117a9e225a30c Mon Sep 17 00:00:00 2001
From: Gregory Shtrasberg
Date: Wed, 8 Jan 2025 21:12:34 +0000
Subject: [PATCH] Revert "[distributed] remove pynccl's redundant stream
 (#11744)"

This reverts commit 635b897246da121238454ed4b2bbc87cb4d4166b.
---
 tests/distributed/test_pynccl.py   |  5 ++--
 .../device_communicators/pynccl.py | 28 +++++++++++++------
 vllm/distributed/parallel_state.py |  3 +-
 3 files changed, 24 insertions(+), 12 deletions(-)

diff --git a/tests/distributed/test_pynccl.py b/tests/distributed/test_pynccl.py
index a77b48d5e49f3..36cfe42251384 100644
--- a/tests/distributed/test_pynccl.py
+++ b/tests/distributed/test_pynccl.py
@@ -137,8 +137,9 @@ def worker_fn_with_cudagraph():
         # run something in the default stream to initialize torch engine
         a = torch.ones((4, 4), device=f'cuda:{pynccl_comm.rank}')
         torch.cuda.synchronize()
-        with torch.cuda.graph(graph), \
-                pynccl_comm.change_state(enable=True):
+        with torch.cuda.graph(
+                graph, stream=pynccl_comm.stream), pynccl_comm.change_state(
+                    enable=True):
             a_out = pynccl_comm.all_reduce(a)
         torch.cuda.synchronize()
         graph.replay()
diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py
index 93d96fd8f5686..a6800f93f167b 100644
--- a/vllm/distributed/device_communicators/pynccl.py
+++ b/vllm/distributed/device_communicators/pynccl.py
@@ -51,6 +51,7 @@ def __init__(
         if self.world_size == 1:
             self.available = False
             self.disabled = True
+            self.stream = None
             return
         try:
             self.nccl = NCCLLibrary(library_path)
@@ -59,6 +60,7 @@ def __init__(
             # e.g. in a non-GPU environment
             self.available = False
             self.disabled = True
+            self.stream = None
             return
 
         self.available = True
@@ -96,12 +98,12 @@ def __init__(
         with torch.cuda.device(device):
             self.comm: ncclComm_t = self.nccl.ncclCommInitRank(
                 self.world_size, self.unique_id, self.rank)
+            self.stream = torch.cuda.Stream()
 
-            stream = torch.cuda.current_stream()
             # A small all_reduce for warmup.
             data = torch.zeros(1, device=device)
             self.all_reduce(data)
-            stream.synchronize()
+            self.stream.synchronize()
             del data
 
     def all_reduce(self,
@@ -120,7 +122,7 @@ def all_reduce(self,
         out_tensor = torch.empty_like(in_tensor)
 
         if stream is None:
-            stream = torch.cuda.current_stream()
+            stream = self.stream
         self.nccl.ncclAllReduce(buffer_type(in_tensor.data_ptr()),
                                 buffer_type(out_tensor.data_ptr()),
                                 in_tensor.numel(),
@@ -142,7 +144,7 @@ def all_gather(self,
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {input_tensor.device}")
         if stream is None:
-            stream = torch.cuda.current_stream()
+            stream = self.stream
         self.nccl.ncclAllGather(
             buffer_type(input_tensor.data_ptr()),
             buffer_type(output_tensor.data_ptr()), input_tensor.numel(),
@@ -163,7 +165,7 @@ def reduce_scatter(self,
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {input_tensor.device}")
         if stream is None:
-            stream = torch.cuda.current_stream()
+            stream = self.stream
         self.nccl.ncclReduceScatter(
             buffer_type(input_tensor.data_ptr()),
             buffer_type(output_tensor.data_ptr()), output_tensor.numel(),
@@ -178,7 +180,7 @@ def send(self, tensor: torch.Tensor, dst: int, stream=None):
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {tensor.device}")
         if stream is None:
-            stream = torch.cuda.current_stream()
+            stream = self.stream
         self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(),
                            ncclDataTypeEnum.from_torch(tensor.dtype), dst,
                            self.comm, cudaStream_t(stream.cuda_stream))
@@ -190,7 +192,7 @@ def recv(self, tensor: torch.Tensor, src: int, stream=None):
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {tensor.device}")
         if stream is None:
-            stream = torch.cuda.current_stream()
+            stream = self.stream
         self.nccl.ncclRecv(buffer_type(tensor.data_ptr()), tensor.numel(),
                            ncclDataTypeEnum.from_torch(tensor.dtype), src,
                            self.comm, cudaStream_t(stream.cuda_stream))
@@ -202,7 +204,7 @@ def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {tensor.device}")
         if stream is None:
-            stream = torch.cuda.current_stream()
+            stream = self.stream
         if src == self.rank:
             sendbuff = buffer_type(tensor.data_ptr())
             # NCCL requires the sender also to have a receive buffer
@@ -215,7 +217,9 @@ def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
                                 self.comm, cudaStream_t(stream.cuda_stream))
 
     @contextmanager
-    def change_state(self, enable: Optional[bool] = None):
+    def change_state(self,
+                     enable: Optional[bool] = None,
+                     stream: Optional[torch.cuda.Stream] = None):
         """
         A context manager to change the state of the communicator.
""" @@ -223,9 +227,15 @@ def change_state(self, enable: Optional[bool] = None): # guess a default value when not specified enable = self.available + if stream is None: + stream = self.stream + old_disable = self.disabled + old_stream = self.stream + self.stream = stream self.disabled = not enable yield self.disabled = old_disable + self.stream = old_stream diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index dccd3addbcb35..a0d4235460f3b 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -310,7 +310,8 @@ def graph_capture( if not pynccl_comm: maybe_pynccl_context = nullcontext() else: - maybe_pynccl_context = pynccl_comm.change_state() + maybe_pynccl_context = pynccl_comm.change_state( + stream=torch.cuda.current_stream()) with maybe_pynccl_context: yield graph_capture_context