csrc/trtllm_mnnvl_allreduce.cu (125 changes: 69 additions & 56 deletions)

@@ -26,77 +26,90 @@ using tvm::ffi::Optional;
 } \
 }()
 
-void trtllm_mnnvl_all_reduce(TensorView in, int64_t multicast_buffer_ptr, int64_t buffer_ptrs_dev,
-                             int64_t buffer_M, TensorView buffer_flags_mnnvl, int64_t nranks,
-                             int64_t rank, bool wait_for_results, bool launch_with_pdl,
-                             Optional<TensorView> out) {
-  cudaSetDevice(in.device().device_id);
-  auto stream = get_stream(in.device());
+void trtllm_mnnvl_allreduce_fusion(TensorView input, int64_t multicast_buffer_ptr,
+                                   int64_t buffer_ptrs_dev, int64_t buffer_ptr_local,
+                                   TensorView buffer_flags_mnnvl, int64_t nranks, int64_t rank,
+                                   bool rmsnorm_fusion, bool launch_with_pdl, bool use_oneshot,
+                                   TensorView output, Optional<TensorView> residual_out,
+                                   Optional<TensorView> residual_in, Optional<TensorView> gamma,
+                                   Optional<double> epsilon) {
+  cudaSetDevice(input.device().device_id);
+  auto stream = get_stream(input.device());
 
-  DISPATCH_FLOATING_TYPES_FOR_MNNVL_ALLREDUCE(in.dtype(), c_type, [&] {
+  DISPATCH_FLOATING_TYPES_FOR_MNNVL_ALLREDUCE(input.dtype(), c_type, [&] {
     // Extract parameters from tensors
-    int64_t num_tokens = in.size(0);
-    int64_t token_dim = in.size(1);
+    int64_t num_tokens = input.size(0);
+    int64_t token_dim = input.size(1);
 
     // Validate input parameters
-    TVM_FFI_ICHECK_EQ(token_dim % (sizeof(float2) / sizeof(c_type)), 0)
-        << "token_dim must be divisible by " << sizeof(float2) / sizeof(c_type);
+    TVM_FFI_ICHECK_EQ(token_dim % (sizeof(float4) / sizeof(c_type)), 0)
+        << "token_dim must be divisible by " << sizeof(float4) / sizeof(c_type);
+    TVM_FFI_ICHECK(output.size(0) == input.size(0) && output.size(1) == input.size(1))
+        << "output shape mismatch: expected (" << input.size(0) << ", " << input.size(1)
+        << ") but got (" << output.size(0) << ", " << output.size(1) << ")";
     TVM_FFI_ICHECK(nranks >= 2 && nranks <= 64)
         << "nranks must be between 2 and 64, got " << nranks;
     TVM_FFI_ICHECK(rank >= 0 && rank < nranks)
         << "rank must be between 0 and nranks-1, got " << rank;
-    TVM_FFI_ICHECK(out.has_value() || !wait_for_results)
-        << "out tensor must be provided if wait_for_results is true";
+    TVM_FFI_ICHECK((residual_in.has_value() && residual_out.has_value() && gamma.has_value() &&
+                    epsilon.has_value()) ||
+                   !rmsnorm_fusion)
+        << "residual_in, residual_out, gamma, and epsilon must be provided if rmsnorm_fusion is "
+           "true";
+
+    if (rmsnorm_fusion) {
+      TVM_FFI_ICHECK(residual_in.value().size(0) == num_tokens &&
+                     residual_in.value().size(1) == token_dim)
+          << "residual_in shape mismatch: expected (" << input.size(0) << ", " << input.size(1)
+          << ") but got (" << residual_in.value().size(0) << ", " << residual_in.value().size(1)
+          << ")";
+      TVM_FFI_ICHECK(residual_out.value().size(0) == num_tokens &&
+                     residual_out.value().size(1) == token_dim)
+          << "residual_out shape mismatch: expected (" << input.size(0) << ", " << input.size(1)
+          << ") but got (" << residual_out.value().size(0) << ", " << residual_out.value().size(1)
+          << ")";
+      TVM_FFI_ICHECK(gamma.value().size(0) == token_dim)
+          << "gamma must have the same shape as token dimension (" << token_dim << ") but got ("
+          << gamma.value().size(0) << ")";
+    }
 
     // Create the parameters struct
-    AllReduceParams<c_type> params;
-    params.nranks = nranks;
-    params.rank = rank;
-    params.buffer_M = buffer_M;
-    params.num_tokens = num_tokens;
-    params.token_dim = token_dim;
-    params.buffer_ptrs_dev = reinterpret_cast<void**>(buffer_ptrs_dev);
-    params.multicast_ptr = reinterpret_cast<void*>(multicast_buffer_ptr);
-    params.buffer_flags = buffer_flags_mnnvl.data_ptr();
-    params.wait_for_results = wait_for_results;
-    params.launch_with_pdl = launch_with_pdl;
-    params.input = in.data_ptr();
-    params.output = out.has_value() ? out.value().data_ptr() : nullptr;
-    params.stream = stream;
+    AllReduceFusionParams params;
 
-    auto status = twoshot_allreduce_dispatch_world_size<c_type>(params);
-    TVM_FFI_ICHECK(status == cudaSuccess)
-        << "twoshot_allreduce_dispatch_world_size failed with error code "
-        << cudaGetErrorString(status);
-  });
-}
+    // Aux Information
+    params.nRanks = nranks;
+    params.rank = rank;
+    params.numTokens = num_tokens;
+    params.tokenDim = token_dim;
+    params.bufferPtrsDev = reinterpret_cast<void**>(buffer_ptrs_dev);
+    params.bufferPtrLocal = reinterpret_cast<void*>(buffer_ptr_local);
+    params.multicastPtr = reinterpret_cast<void*>(multicast_buffer_ptr);
+    params.bufferFlags = reinterpret_cast<uint32_t*>(buffer_flags_mnnvl.data_ptr());
+    params.rmsNormFusion = rmsnorm_fusion;
+    params.launchWithPdl = launch_with_pdl;
 
-void trtllm_mnnvl_rmsnorm(int64_t multicast_buffer_ptr, TensorView prenorm_output,
-                          TensorView normed_output, TensorView gamma, double epsilon,
-                          TensorView residual, TensorView buffer_flags, bool launch_with_pdl) {
-  cudaSetDevice(prenorm_output.device().device_id);
-  auto stream = get_stream(prenorm_output.device());
+    // input data
+    params.input = const_cast<void const*>(input.data_ptr());
+    params.residualIn =
+        residual_in.has_value() ? const_cast<void const*>(residual_in.value().data_ptr()) : nullptr;
+    params.gamma = gamma.has_value() ? const_cast<void const*>(gamma.value().data_ptr()) : nullptr;
+    params.epsilon = epsilon.has_value() ? epsilon.value() : 1e-5;
 
-  DISPATCH_FLOATING_TYPES_FOR_MNNVL_ALLREDUCE(prenorm_output.dtype(), c_type, [&] {
-    // Create the parameters struct
-    RMSNormParams<c_type> params;
-    params.residual_output = prenorm_output.data_ptr();
-    params.output = normed_output.data_ptr();
-    params.input = reinterpret_cast<void const*>(multicast_buffer_ptr);
-    params.gamma = gamma.data_ptr();
-    params.epsilon = epsilon;
-    params.residual = residual.data_ptr();
-    params.buffer_flags = reinterpret_cast<uint32_t*>(buffer_flags.data_ptr());
-    params.batch = normed_output.size(0);
-    params.hidden_dim = normed_output.size(1);
+    // output data
+    params.output = const_cast<void*>(output.data_ptr());
+    params.residualOut =
+        residual_out.has_value() ? const_cast<void*>(residual_out.value().data_ptr()) : nullptr;
     params.stream = stream;
-    params.launch_with_pdl = launch_with_pdl;
-    auto status = twoshot_rmsnorm_dispatch_hidden_dim<c_type>(params);
+
+    cudaError_t status;
+    if (use_oneshot) {
+      status = oneshotAllreduceFusionDispatch<c_type>(params);
+    } else {
+      status = twoshotAllreduceFusionDispatch<c_type>(params);
+    }
     TVM_FFI_ICHECK(status == cudaSuccess)
-        << "twoshot_rmsnorm_dispatch_hidden_dim failed with error code "
-        << cudaGetErrorString(status);
+        << "trtllm_mnnvl_allreduce_fusion failed with error code " << cudaGetErrorString(status);
   });
 }
 
-TVM_FFI_DLL_EXPORT_TYPED_FUNC(trtllm_mnnvl_all_reduce, trtllm_mnnvl_all_reduce);
-TVM_FFI_DLL_EXPORT_TYPED_FUNC(trtllm_mnnvl_rmsnorm, trtllm_mnnvl_rmsnorm);
+TVM_FFI_DLL_EXPORT_TYPED_FUNC(trtllm_mnnvl_allreduce_fusion, trtllm_mnnvl_allreduce_fusion);
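
Note for reviewers: the new entry point replaces the old AllReduceParams<c_type>/RMSNormParams<c_type> pair with a single AllReduceFusionParams struct. Its definition is not part of this file, but its shape can be read off the assignments above. The sketch below is an inferred reconstruction, not the actual declaration: field names are taken from the diff, field types are guesses based on the values assigned to them.

// Hypothetical reconstruction of AllReduceFusionParams, inferred from the
// assignments in this diff; the real definition lives in the kernel header.
struct AllReduceFusionParams {
  // Aux information
  int64_t nRanks;         // world size, validated to be in [2, 64]
  int64_t rank;           // this rank, in [0, nRanks)
  int64_t numTokens;      // input.size(0)
  int64_t tokenDim;       // input.size(1), divisible by sizeof(float4) / sizeof(c_type)
  void** bufferPtrsDev;   // device array of per-rank buffer pointers
  void* bufferPtrLocal;   // this rank's local staging buffer
  void* multicastPtr;     // NVLink multicast address
  uint32_t* bufferFlags;  // synchronization flags
  bool rmsNormFusion;     // fuse RMSNorm into the allreduce epilogue
  bool launchWithPdl;     // use programmatic dependent launch

  // Input data
  void const* input;
  void const* residualIn;  // nullptr when rmsNormFusion is false
  void const* gamma;       // nullptr when rmsNormFusion is false
  double epsilon;          // defaults to 1e-5 when not supplied

  // Output data
  void* output;
  void* residualOut;  // nullptr when rmsNormFusion is false
  cudaStream_t stream;
};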
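One behavioral change that is easy to miss: the vectorization check tightens from float2 (8 bytes) to float4 (16 bytes), doubling the required token_dim alignment. A standalone sketch (not code from this PR) of what the new constraint means per element type:

#include <cstdint>
#include <cstdio>

// The kernel now moves 16-byte (float4) packets, so each row must contain a
// whole number of packets: token_dim % (16 / sizeof(element)) == 0.
int main() {
  constexpr int64_t kPacketBytes = 16;  // sizeof(float4)
  struct { const char* name; int64_t bytes; } dtypes[] = {
      {"half/bfloat16", 2},  // now token_dim % 8 == 0 (was % 4 with float2)
      {"float", 4},          // now token_dim % 4 == 0 (was % 2 with float2)
  };
  for (auto& d : dtypes) {
    std::printf("%s: token_dim must be divisible by %lld\n", d.name,
                static_cast<long long>(kPacketBytes / d.bytes));
  }
  return 0;
}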