Commit 5f6e999

first pass at exposing cuBlasMp to TE/PyTorch
Signed-off-by: Alp Dener <adener@nvidia.com>
1 parent f62cad9 commit 5f6e999

9 files changed (+374, -326 lines)

transformer_engine/common/include/transformer_engine/comm_gemm.h

Lines changed: 8 additions & 0 deletions
@@ -44,6 +44,14 @@ enum NVTECommGemmAlgoType {
   kNVTECommGemmAlgoAtomicMulticast = 4
 };
 
+bool nvte_built_with_cublasmp() {
+#ifdef NVTE_WITH_CUBLASMP
+  return true;
+#else
+  return false;
+#endif
+}
+
 /*! \brief Create a comm-gemm context.
  *
  *  \param[in] comm  NCCL communicator.
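
The new query gives callers a runtime check for cuBlasMp support instead of repeating the NVTE_WITH_CUBLASMP #ifdef at every call site. A minimal sketch of how calling code might branch on it; the select_backend helper and the backend labels are illustrative, only nvte_built_with_cublasmp() comes from this commit:

#include <cstdio>
#include <transformer_engine/comm_gemm.h>

// Hypothetical helper: pick a comm-gemm backend label based on how the
// library was built. Only nvte_built_with_cublasmp() is from this commit.
static const char *select_backend() {
  return nvte_built_with_cublasmp() ? "cublasmp" : "userbuffers";
}

int main() {
  std::printf("comm-gemm backend: %s\n", select_backend());
  return 0;
}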

transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h

Lines changed: 2 additions & 0 deletions
@@ -26,6 +26,8 @@ namespace transformer_engine {
  */
 bool ubuf_built_with_mpi();
 
+enum class CommOverlapMethod { BULK = 0, PIPELINE = 1, RING_EXCHANGE = 2 };
+
 enum class CommOverlapType { RS = 0, AG = 1 };
 
 enum class CommOverlapAlgo {
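
CommOverlapMethod collapses the previous bulk / pipelined / ring-exchange split into a single selector that the bindings can pass around. A rough dispatch sketch; mapping BULK and PIPELINE to the collective backend and RING_EXCHANGE to the point-to-point backend is an assumption for illustration, not something this diff states:

#include <stdexcept>
#include <transformer_engine/comm_gemm_overlap.h>

using transformer_engine::CommOverlapMethod;

// Hypothetical description of which overlap backend a method would select.
const char *describe_method(CommOverlapMethod method) {
  switch (method) {
    case CommOverlapMethod::BULK:
    case CommOverlapMethod::PIPELINE:
      return "collective overlap (CommOverlapBase-style)";
    case CommOverlapMethod::RING_EXCHANGE:
      return "point-to-point overlap (CommOverlapP2PBase-style)";
  }
  throw std::invalid_argument("unknown CommOverlapMethod");
}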

transformer_engine/common/util/pybind_helper.h

Lines changed: 8 additions & 0 deletions
@@ -9,6 +9,7 @@
 
 #include <pybind11/pybind11.h>
 #include <transformer_engine/comm_gemm_overlap.h>
+#include <transformer_engine/comm_gemm.h>
 #include <transformer_engine/fused_attn.h>
 #include <transformer_engine/transformer_engine.h>
 
@@ -84,6 +85,11 @@
       m, "Float8BlockScaleTensorFormat", pybind11::module_local()) \
       .value("GEMM_READY", transformer_engine::Float8BlockScaleTensorFormat::GEMM_READY) \
       .value("COMPACT", transformer_engine::Float8BlockScaleTensorFormat::COMPACT); \
+  pybind11::enum_<transformer_engine::CommOverlapMethod>(m, "CommOverlapMethod", \
+                                                         pybind11::module_local()) \
+      .value("BULK", transformer_engine::CommOverlapMethod::BULK) \
+      .value("PIPELINE", transformer_engine::CommOverlapMethod::PIPELINE) \
+      .value("RING_EXCHANGE", transformer_engine::CommOverlapMethod::RING_EXCHANGE); \
   pybind11::enum_<transformer_engine::CommOverlapType>(m, "CommOverlapType", \
                                                        pybind11::module_local()) \
       .value("RS", transformer_engine::CommOverlapType::RS) \
@@ -135,6 +141,8 @@
       }, \
       py::call_guard<py::gil_scoped_release>(), py::arg("device_id") = -1); \
   m.def("ubuf_built_with_mpi", &transformer_engine::ubuf_built_with_mpi, \
+        py::call_guard<py::gil_scoped_release>()); \
+  m.def("nvte_built_with_cublasmp", &nvte_built_with_cublasmp, \
         py::call_guard<py::gil_scoped_release>());
 
 #endif

transformer_engine/pytorch/csrc/common.h

Lines changed: 1 addition & 0 deletions
@@ -27,6 +27,7 @@
 #include <transformer_engine/cast.h>
 #include <transformer_engine/cast_transpose_noop.h>
 #include <transformer_engine/comm_gemm_overlap.h>
+#include <transformer_engine/comm_gemm.h>
 #include <transformer_engine/fused_attn.h>
 #include <transformer_engine/fused_rope.h>
 #include <transformer_engine/fused_router.h>

transformer_engine/pytorch/csrc/extensions.h

Lines changed: 46 additions & 33 deletions
@@ -128,7 +128,7 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool transb,
                              py::handle quantizer, std::optional<DType> out_dtype, MaybeTensor bias,
                              DType bias_type, bool gelu, MaybeTensor gelu_in, bool grad,
                              at::Tensor workspace, size_t workspaceSize, bool accumulate,
-                             bool use_split_accumulator, CommOverlapCore *comm_overlap = nullptr,
+                             bool use_split_accumulator, CommOverlapManager *comm_overlap = nullptr,
                              std::optional<CommOverlapType> comm_type = std::nullopt,
                              MaybeTensor extra_output = std::nullopt, bool bulk_overlap = false,
                              float alpha = 1.0f, std::optional<float> beta = std::nullopt);
@@ -504,48 +504,56 @@ class CommOverlapHelper : torch::CustomClassHolder {
 
   CommOverlapHelper();
 
-  CommOverlapHelper(c10d::ProcessGroup *world_group,
-                    std::optional<c10d::ProcessGroup *> intra_node_group);
+  CommOverlapHelper(c10d::ProcessGroup *tp_group);
+
+  CommOverlapHelper(c10d::ProcessGroup *world_group, c10d::ProcessGroup *intra_node_group);
 
   ~CommOverlapHelper();
 
   void ub_allgather(void *globaldata, size_t globalbytes, void *localdata, size_t localbytes,
                     ExtComm comm);
 
   void ub_barrier(ExtComm comm);
-};
-
-class CommOverlap : torch::CustomClassHolder, public transformer_engine::CommOverlapBase {
- public:
-  CommOverlap(const std::vector<size_t> &buffer_shape, at::ScalarType buffer_dtype,
-              CommOverlapHelper *helper, int tp_size, int num_splits = 3,
-              int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, int comm_cga_size = 2,
-              int gemm_priority = 0, int comm_priority = 0, int num_comm_sm = 16,
-              bool set_sm_margin = true, bool atomic_gemm = false,
-              bool rs_overlap_first_gemm = false);
-
-  ~CommOverlap() {}
-
-  void copy_into_buffer(const at::Tensor &input, bool local_chunk = false);
 
-  at::Tensor get_buffer(bool local_chunk = false,
-                        std::optional<std::vector<int64_t>> shape = std::nullopt);
-
-  std::pair<at::Stream, at::Stream> get_communication_stream();
+  int64_t get_comm_ptr(std::string group = "world") { return pgs[group]->getCommPtr(); }
+};
 
-};  // CommOverlap
+class CommOverlapManager : torch::CustomClassHolder {
+ private:
+#ifndef NVTE_WITH_CUBLASMP
+  transformer_engine::CommOverlapCore *_ctx;
+#else
+  CommGemmCtx *_ctx;
+#endif
+  transformer_engine::CommOverlapMethod _method;
+  int _num_comm_sm;
+  bool _use_atomic_gemm;
 
-class CommOverlapP2P : torch::CustomClassHolder, public transformer_engine::CommOverlapP2PBase {
  public:
-  CommOverlapP2P(const std::vector<size_t> &buffer_shape, at::ScalarType buffer_dtype,
-                 CommOverlapHelper *helper, int tp_size,
-                 transformer_engine::CommOverlapType comm_type,
-                 int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, int comm_cga_size = 2,
-                 int gemm_priority = 0, int comm_priority = 0, int num_comm_sm = 3,
-                 bool set_sm_margin = true, bool atomic_gemm = false, bool use_ce = true,
-                 bool aggregate = false);
-
-  ~CommOverlapP2P() {}
+  CommOverlapManager(transformer_engine::CommOverlapMethod method,
+                     transformer_engine::CommOverlapType comm_type,
+                     const std::vector<size_t> &buffer_shape, at::ScalarType buffer_dtype,
+                     CommOverlapHelper *helper, int tp_size, int num_splits = 3,
+                     int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, int comm_cga_size = 2,
+                     int gemm_priority = 0, int comm_priority = 0, int num_comm_sm = 16,
+                     bool set_sm_margin = false, bool atomic_gemm = false,
+                     bool aggregate_ag = false, bool rs_overlap_first_gemm = false);
+
+  ~CommOverlapManager() {
+#ifdef NVTE_WITH_CUBLASMP
+    nvte_comm_gemm_ctx_destroy(_ctx);
+#else
+    delete _ctx;
+#endif;
+  }
+
+  bool is_fp8_ubuf() {
+#ifndef NVTE_WITH_CUBLASMP
+    return _ctx->is_fp8_ubuf();
+#else
+    return false;
+#endif
+  }
 
   void copy_into_buffer(const at::Tensor &input, bool local_chunk = false);
 
@@ -554,6 +562,11 @@ class CommOverlapP2P : torch::CustomClassHolder, public transformer_engine::CommOverlapP2PBase {
 
   std::pair<at::Stream, at::Stream> get_communication_stream();
 
-};  // CommOverlapP2P
+  void execute(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb,
+               TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out,
+               TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator,
+               transformer_engine::CommOverlapType comm_type, TensorWrapper &aux_out,
+               cudaStream_t stream);
+};  // CommOverlapManager
 
 #endif  // TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_H_
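
Taken together, CommOverlapManager hides whether the underlying context is a userbuffers CommOverlapCore or a cuBlasMp CommGemmCtx behind one class. A minimal construction sketch against the declaration above; the include path, helper, tensor-parallel size, and buffer shape are illustrative values, not part of the commit:

#include <vector>
#include <ATen/ATen.h>
#include "extensions.h"  // transformer_engine/pytorch/csrc/extensions.h (path illustrative)

// Hypothetical factory: build a ring-exchange all-gather overlap manager for a
// tensor-parallel group. Defaulted arguments (num_splits, stream counts, SM
// margin, ...) stay at the values declared in the header.
CommOverlapManager *make_ag_overlap(CommOverlapHelper *helper, int tp_size) {
  const std::vector<size_t> buffer_shape{4096, 1024};  // (tokens, hidden) -- placeholder sizes
  return new CommOverlapManager(transformer_engine::CommOverlapMethod::RING_EXCHANGE,
                                transformer_engine::CommOverlapType::AG, buffer_shape,
                                at::kBFloat16, helper, tp_size);
}

The same object then exposes execute(...) for the overlapped GEMM plus the copy_into_buffer and get_communication_stream helpers carried over from the old CommOverlap/CommOverlapP2P classes.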
