Skip to content

Commit

Permalink
[xla:gpu] Do not use ncclSend and ncclRecv directly and use NcclApi
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 599035702
  • Loading branch information
ezhulenev authored and copybara-github committed Jan 17, 2024
1 parent 87dd8fe commit f842a3e
Show file tree
Hide file tree
Showing 4 changed files with 118 additions and 82 deletions.
120 changes: 38 additions & 82 deletions xla/service/gpu/nccl_all_to_all_thunk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ limitations under the License.
#include <vector>

#include "absl/status/status.h"
#include "absl/strings/str_format.h"
#include "absl/strings/substitute.h"
#include "mlir/IR/Value.h" // from @llvm-project
#include "xla/hlo/ir/hlo_instruction.h"
Expand All @@ -33,11 +32,11 @@ limitations under the License.
#include "xla/service/gpu/nccl_collective_thunk.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/status_macros.h"
#include "xla/stream_executor/device_memory.h"
#include "tsl/platform/errors.h"

#if XLA_ENABLE_XCCL
#include "xla/stream_executor/gpu/gpu_stream.h"
#endif
#include "tsl/platform/logging.h"
#include "tsl/platform/statusor.h"

namespace xla {
namespace gpu {
Expand Down Expand Up @@ -151,60 +150,44 @@ absl::Status NcclAllToAllStartThunk::RunNcclCollective(
absl::Status RunAllToAll(bool has_split_dimension,
std::vector<DeviceBufferPair>& buffers,
se::Stream& stream, ncclComm_t comm) {
#if XLA_ENABLE_XCCL
int device_ordinal = stream.parent()->device_ordinal();
VLOG(3) << "Performing all-to-all from device ordinal: " << device_ordinal;

se::gpu::GpuStreamHandle gpu_stream = se::gpu::AsGpuStreamValue(&stream);

TF_ASSIGN_OR_RETURN(
int32_t num_participants,
NcclApi::CommCount(reinterpret_cast<NcclApi::NcclCommHandle>(comm)));

TF_RETURN_IF_ERROR(NcclApi::GroupStart());

// AllToAll can operate in two modes. Either it specifies a split dimension,
// in which case inputs are split and outputs concatenated in that dimension
// (here, we only support dimension 0), or it takes a list of inputs
// and produces a tuple of outputs.
if (has_split_dimension) {
for (size_t i = 0; i < buffers.size(); ++i) {
DeviceBufferPair& buffer = buffers[i];
const uint8_t* send_buffer =
static_cast<uint8_t*>(buffer.source_buffer.opaque());
uint8_t* recv_buffer =
static_cast<uint8_t*>(buffer.destination_buffer.opaque());

TF_ASSIGN_OR_RETURN(auto dtype_and_multiplier,
ToNcclDataTypeAndCountMultiplier(
buffer.element_type, Thunk::kNcclAllToAll));
auto [dtype, multiplier] = dtype_and_multiplier;
int64_t element_count = buffer.element_count;

TF_RET_CHECK(element_count % num_participants == 0)
for (DeviceBufferPair& buffer : buffers) {
TF_RET_CHECK(buffer.element_count % num_participants == 0)
<< "Buffer was not an exact multiple of the number of participants.";
size_t chunk_elements = element_count / num_participants;
size_t chunk_bytes = chunk_elements * ShapeUtil::ByteSizeOfPrimitiveType(
buffer.element_type);

for (int rank = 0; rank < num_participants; ++rank) {
VLOG(3) << absl::StreamFormat(
"Calling ncclSend(sendbuff=%p, count=%d, peer=%d "
"comm=%p, stream=%p)",
send_buffer + rank * chunk_bytes, chunk_elements * multiplier, rank,
static_cast<const void*>(comm), gpu_stream);
XLA_NCCL_RETURN_IF_ERROR(ncclSend(send_buffer + rank * chunk_bytes,
chunk_elements * multiplier, dtype,
rank, comm, gpu_stream));

VLOG(3) << absl::StreamFormat(
"Calling ncclRecv(recvbuff=%p, count=%d, peer=%d "
"comm=%p, stream=%p)",
recv_buffer + rank * chunk_bytes, chunk_elements * multiplier, rank,
static_cast<const void*>(comm), gpu_stream);

XLA_NCCL_RETURN_IF_ERROR(ncclRecv(recv_buffer + rank * chunk_bytes,
chunk_elements * multiplier, dtype,
rank, comm, gpu_stream));

size_t chunk_elements = buffer.element_count / num_participants;

for (int peer = 0; peer < num_participants; ++peer) {
TF_ASSIGN_OR_RETURN(
se::DeviceMemoryBase send_slice,
NcclApi::Slice(buffer.source_buffer, buffer.element_type,
peer * chunk_elements, chunk_elements));

TF_ASSIGN_OR_RETURN(
se::DeviceMemoryBase recv_slice,
NcclApi::Slice(buffer.destination_buffer, buffer.element_type,
peer * chunk_elements, chunk_elements));

TF_RETURN_IF_ERROR(NcclApi::Send(
send_slice, buffer.element_type, chunk_elements, peer,
reinterpret_cast<NcclApi::NcclCommHandle>(comm), &stream));

TF_RETURN_IF_ERROR(NcclApi::Recv(
recv_slice, buffer.element_type, chunk_elements, peer,
reinterpret_cast<NcclApi::NcclCommHandle>(comm), &stream));
}
}
} else {
Expand All @@ -213,45 +196,18 @@ absl::Status RunAllToAll(bool has_split_dimension,

for (size_t i = 0; i < buffers.size(); ++i) {
DeviceBufferPair& buffer = buffers[i];
const uint8_t* send_buffer =
static_cast<uint8_t*>(buffer.source_buffer.opaque());
uint8_t* recv_buffer =
static_cast<uint8_t*>(buffer.destination_buffer.opaque());

TF_ASSIGN_OR_RETURN(auto dtype_and_multiplier,
ToNcclDataTypeAndCountMultiplier(
buffer.element_type, Thunk::kNcclAllToAll));
auto [dtype, multiplier] = dtype_and_multiplier;
int64_t element_count = buffer.element_count * multiplier;

VLOG(3) << absl::StreamFormat(
"Calling ncclSend(sendbuff=%p, count=%d, peer=%d "
"comm=%p, stream=%p)",
send_buffer, element_count, i, static_cast<const void*>(comm),
gpu_stream);

XLA_NCCL_RETURN_IF_ERROR(ncclSend(send_buffer, element_count, dtype,
/*rank=*/i, comm, gpu_stream));

VLOG(3) << absl::StreamFormat(
"Calling ncclRecv(recvbuff=%p, count=%d, peer=%d "
"comm=%p, stream=%p)",
recv_buffer, element_count, i, static_cast<const void*>(comm),
gpu_stream);

XLA_NCCL_RETURN_IF_ERROR(ncclRecv(recv_buffer, element_count, dtype,
/*rank=*/i, comm, gpu_stream));

TF_RETURN_IF_ERROR(NcclApi::Send(
buffer.source_buffer, buffer.element_type, buffer.element_count, i,
reinterpret_cast<NcclApi::NcclCommHandle>(comm), &stream));

TF_RETURN_IF_ERROR(NcclApi::Recv(
buffer.destination_buffer, buffer.element_type, buffer.element_count,
i, reinterpret_cast<NcclApi::NcclCommHandle>(comm), &stream));
}
}
TF_RETURN_IF_ERROR(NcclApi::GroupEnd());

VLOG(3) << "Done performing all-to-all for ordinal: " << device_ordinal;
return absl::OkStatus();
#else // XLA_ENABLE_XCCL
return Unimplemented(
"NCCL support is not available: this binary was not built with a CUDA "
"compiler, which is necessary to build the NCCL source library.");
#endif // XLA_ENABLE_XCCL

return NcclApi::GroupEnd();
}

} // namespace gpu
Expand Down
43 changes: 43 additions & 0 deletions xla/service/gpu/nccl_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ limitations under the License.
#include "xla/primitive_util.h"
#include "xla/service/collective_ops_utils.h"
#include "xla/service/gpu/nccl_clique_key.h"
#include "xla/shape_util.h"
#include "xla/stream_executor/device_memory.h"
#include "xla/stream_executor/gpu/gpu_stream.h"
#include "xla/stream_executor/stream.h"
Expand Down Expand Up @@ -172,6 +173,14 @@ static ncclUniqueId AsNcclUniqueId(const NcclCliqueId& clique_id) {
return id;
}

// Returns a sub-buffer of `buff` covering `count` elements of `dtype`
// starting at element index `offset` (both expressed in elements, not bytes).
absl::StatusOr<se::DeviceMemoryBase> NcclApi::Slice(se::DeviceMemoryBase buff,
                                                    PrimitiveType dtype,
                                                    size_t offset,
                                                    size_t count) {
  // Width in bytes of one element of `dtype`; converts element offsets/counts
  // into the byte offsets/sizes that GetByteSlice expects.
  const size_t element_bytes = ShapeUtil::ByteSizeOfPrimitiveType(dtype);
  return buff.GetByteSlice(offset * element_bytes, count * element_bytes);
}

absl::StatusOr<NcclCliqueId> NcclApi::GetUniqueId() {
VLOG(3) << "Get NCCL unique id";
ncclUniqueId id;
Expand Down Expand Up @@ -289,4 +298,38 @@ absl::Status NcclApi::AllGather(se::DeviceMemoryBase send_buffer,
nccl_dtype, Cast(comm), se::gpu::AsGpuStreamValue(stream)));
}

// Enqueues a point-to-point NCCL send of `count` elements of `dtype` from
// `send_buffer` to rank `peer` on communicator `comm`, ordered on `stream`.
absl::Status NcclApi::Send(se::DeviceMemoryBase send_buffer,
                           PrimitiveType dtype, size_t count, int32_t peer,
                           NcclCommHandle comm, se::Stream* stream) {
  VLOG(3) << absl::StreamFormat(
      "Launch NCCL Send operation on device #%d; send_buffer=%p; dtype=%s; "
      "count=%d; peer=%d; comm=%p; stream=%p",
      stream->parent()->device_ordinal(), send_buffer.opaque(),
      primitive_util::LowercasePrimitiveTypeName(dtype), count, peer, comm,
      stream);

  // Map the XLA primitive type to the NCCL wire type. The second argument is
  // presumably an is-reduction flag — false for point-to-point; confirm
  // against ToNcclDataType's declaration.
  TF_ASSIGN_OR_RETURN(ncclDataType_t nccl_type, ToNcclDataType(dtype, false));

  se::gpu::GpuStreamHandle gpu_stream = se::gpu::AsGpuStreamValue(stream);
  return XLA_NCCL_STATUS(ncclSend(send_buffer.opaque(),
                                  ToNcclCount(dtype, count), nccl_type, peer,
                                  Cast(comm), gpu_stream));
}

// Enqueues a point-to-point NCCL receive of `count` elements of `dtype` from
// rank `peer` into `recv_buffer` on communicator `comm`, ordered on `stream`.
absl::Status NcclApi::Recv(se::DeviceMemoryBase recv_buffer,
                           PrimitiveType dtype, size_t count, int32_t peer,
                           NcclCommHandle comm, se::Stream* stream) {
  VLOG(3) << absl::StreamFormat(
      "Launch NCCL Recv operation on device #%d; recv_buffer=%p; dtype=%s; "
      "count=%d; peer=%d; comm=%p; stream=%p",
      stream->parent()->device_ordinal(), recv_buffer.opaque(),
      primitive_util::LowercasePrimitiveTypeName(dtype), count, peer, comm,
      stream);

  // Map the XLA primitive type to the NCCL wire type. The second argument is
  // presumably an is-reduction flag — false for point-to-point; confirm
  // against ToNcclDataType's declaration.
  TF_ASSIGN_OR_RETURN(ncclDataType_t nccl_type, ToNcclDataType(dtype, false));

  se::gpu::GpuStreamHandle gpu_stream = se::gpu::AsGpuStreamValue(stream);
  return XLA_NCCL_STATUS(ncclRecv(recv_buffer.opaque(),
                                  ToNcclCount(dtype, count), nccl_type, peer,
                                  Cast(comm), gpu_stream));
}

} // namespace xla::gpu
21 changes: 21 additions & 0 deletions xla/service/gpu/nccl_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,13 @@ struct NcclApi {
// Convenience handles for defining API functions.
using NcclCommHandle = NcclComm*;

// Returns a slice of device memory `buff` containing `count` values of data
// type `dtype` starting from `offset`.
static absl::StatusOr<se::DeviceMemoryBase> Slice(se::DeviceMemoryBase buff,
PrimitiveType dtype,
size_t offset,
size_t count);

// Creates a new unique clique id.
//
// https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/comms.html#ncclgetuniqueid
Expand Down Expand Up @@ -111,6 +118,20 @@ struct NcclApi {
se::DeviceMemoryBase recv_buffer,
PrimitiveType dtype, size_t count,
NcclCommHandle comm, se::Stream* stream);

// Send data from `send_buff` to rank `peer`.
//
// https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/p2p.html#ncclsend
static absl::Status Send(se::DeviceMemoryBase send_buffer,
PrimitiveType dtype, size_t count, int32_t peer,
NcclCommHandle comm, se::Stream* stream);

// Receive data from rank `peer` into `recv_buff`.
//
// https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/p2p.html#ncclrecv
static absl::Status Recv(se::DeviceMemoryBase recv_buffer,
PrimitiveType dtype, size_t count, int32_t peer,
NcclCommHandle comm, se::Stream* stream);
};

//===----------------------------------------------------------------------===//
Expand Down
16 changes: 16 additions & 0 deletions xla/service/gpu/nccl_api_stub.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@ limitations under the License.

namespace xla::gpu {

// Stub for builds without NCCL: always fails with kUnimplemented.
absl::StatusOr<se::DeviceMemoryBase> NcclApi::Slice(se::DeviceMemoryBase,
                                                    PrimitiveType, size_t,
                                                    size_t) {
  return absl::UnimplementedError("XLA compiled without NCCL support");
}

// Stub for builds without NCCL: always fails with kUnimplemented.
absl::StatusOr<NcclCliqueId> NcclApi::GetUniqueId() {
  return absl::UnimplementedError("XLA compiled without NCCL support");
}
Expand Down Expand Up @@ -74,4 +80,14 @@ absl::Status NcclApi::AllGather(se::DeviceMemoryBase, se::DeviceMemoryBase,
return absl::UnimplementedError("XLA compiled without NCCL support");
}

// Stub for builds without NCCL: always fails with kUnimplemented.
absl::Status NcclApi::Send(se::DeviceMemoryBase, PrimitiveType, size_t, int32_t,
                           NcclCommHandle, se::Stream*) {
  return absl::UnimplementedError("XLA compiled without NCCL support");
}

// Stub for builds without NCCL: always fails with kUnimplemented.
absl::Status NcclApi::Recv(se::DeviceMemoryBase, PrimitiveType, size_t, int32_t,
                           NcclCommHandle, se::Stream*) {
  return absl::UnimplementedError("XLA compiled without NCCL support");
}

} // namespace xla::gpu

0 comments on commit f842a3e

Please sign in to comment.