@@ -128,7 +128,7 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
                              py::handle quantizer, std::optional<DType> out_dtype, MaybeTensor bias,
                              DType bias_type, bool gelu, MaybeTensor gelu_in, bool grad,
                              at::Tensor workspace, size_t workspaceSize, bool accumulate,
-                             bool use_split_accumulator, CommOverlapCore *comm_overlap = nullptr,
+                             bool use_split_accumulator, CommOverlapManager *comm_overlap = nullptr,
                              std::optional<CommOverlapType> comm_type = std::nullopt,
                              MaybeTensor extra_output = std::nullopt, bool bulk_overlap = false,
                              float alpha = 1.0f, std::optional<float> beta = std::nullopt);
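
Note on the first hunk: the only change for callers is the type of the overlap handle (CommOverlapManager* instead of CommOverlapCore*); every parameter from comm_overlap onward is still defaulted, so call sites that stop at use_split_accumulator keep compiling. A toy, self-contained sketch of that trailing-default pattern (stand-in types only, not the real extension signature):

    #include <optional>

    struct CommOverlapManager {};  // stand-in for the real torch::CustomClassHolder type

    // Mirrors the shape of the gemm() declaration above: overlap-related
    // parameters trail the required ones and are all defaulted.
    void gemm_like(float a, float b, CommOverlapManager *comm_overlap = nullptr,
                   std::optional<int> comm_type = std::nullopt, float alpha = 1.0f,
                   std::optional<float> beta = std::nullopt) {
      (void)a; (void)b; (void)comm_overlap; (void)comm_type; (void)alpha; (void)beta;
    }

    int main() {
      CommOverlapManager mgr;
      gemm_like(1.0f, 2.0f);        // pre-existing call sites: no overlap handle
      gemm_like(1.0f, 2.0f, &mgr);  // overlap-enabled call sites pass the manager
      return 0;
    }
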
@@ -504,48 +504,56 @@ class CommOverlapHelper : torch::CustomClassHolder {
 
   CommOverlapHelper();
 
-  CommOverlapHelper(c10d::ProcessGroup *world_group,
-                    std::optional<c10d::ProcessGroup *> intra_node_group);
+  CommOverlapHelper(c10d::ProcessGroup *tp_group);
+
+  CommOverlapHelper(c10d::ProcessGroup *world_group, c10d::ProcessGroup *intra_node_group);
 
   ~CommOverlapHelper();
 
   void ub_allgather(void *globaldata, size_t globalbytes, void *localdata, size_t localbytes,
                     ExtComm comm);
 
   void ub_barrier(ExtComm comm);
-};
-
-class CommOverlap : torch::CustomClassHolder, public transformer_engine::CommOverlapBase {
- public:
-  CommOverlap(const std::vector<size_t> &buffer_shape, at::ScalarType buffer_dtype,
-              CommOverlapHelper *helper, int tp_size, int num_splits = 3,
-              int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, int comm_cga_size = 2,
-              int gemm_priority = 0, int comm_priority = 0, int num_comm_sm = 16,
-              bool set_sm_margin = true, bool atomic_gemm = false,
-              bool rs_overlap_first_gemm = false);
-
-  ~CommOverlap() {}
-
-  void copy_into_buffer(const at::Tensor &input, bool local_chunk = false);
 
-  at::Tensor get_buffer(bool local_chunk = false,
-                        std::optional<std::vector<int64_t>> shape = std::nullopt);
-
-  std::pair<at::Stream, at::Stream> get_communication_stream();
+  int64_t get_comm_ptr(std::string group = "world") { return pgs[group]->getCommPtr(); }
+};
 
-};  // CommOverlap
+class CommOverlapManager : torch::CustomClassHolder {
+ private:
+#ifndef NVTE_WITH_CUBLASMP
+  transformer_engine::CommOverlapCore *_ctx;
+#else
+  CommGemmCtx *_ctx;
+#endif
+  transformer_engine::CommOverlapMethod _method;
+  int _num_comm_sm;
+  bool _use_atomic_gemm;
 
-class CommOverlapP2P : torch::CustomClassHolder, public transformer_engine::CommOverlapP2PBase {
  public:
-  CommOverlapP2P(const std::vector<size_t> &buffer_shape, at::ScalarType buffer_dtype,
-                 CommOverlapHelper *helper, int tp_size,
-                 transformer_engine::CommOverlapType comm_type,
-                 int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, int comm_cga_size = 2,
-                 int gemm_priority = 0, int comm_priority = 0, int num_comm_sm = 3,
-                 bool set_sm_margin = true, bool atomic_gemm = false, bool use_ce = true,
-                 bool aggregate = false);
-
-  ~CommOverlapP2P() {}
+  CommOverlapManager(transformer_engine::CommOverlapMethod method,
+                     transformer_engine::CommOverlapType comm_type,
+                     const std::vector<size_t> &buffer_shape, at::ScalarType buffer_dtype,
+                     CommOverlapHelper *helper, int tp_size, int num_splits = 3,
+                     int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, int comm_cga_size = 2,
+                     int gemm_priority = 0, int comm_priority = 0, int num_comm_sm = 16,
+                     bool set_sm_margin = false, bool atomic_gemm = false,
+                     bool aggregate_ag = false, bool rs_overlap_first_gemm = false);
+
+  ~CommOverlapManager() {
+#ifdef NVTE_WITH_CUBLASMP
+    nvte_comm_gemm_ctx_destroy(_ctx);
+#else
+    delete _ctx;
+#endif
+  }
+
+  bool is_fp8_ubuf() {
+#ifndef NVTE_WITH_CUBLASMP
+    return _ctx->is_fp8_ubuf();
+#else
+    return false;
+#endif
+  }
 
   void copy_into_buffer(const at::Tensor &input, bool local_chunk = false);
 
@@ -554,6 +562,11 @@ class CommOverlapP2P : torch::CustomClassHolder, public transformer_engine::Comm
 
   std::pair<at::Stream, at::Stream> get_communication_stream();
 
-};  // CommOverlapP2P
+  void execute(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb,
+               TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out,
+               TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator,
+               transformer_engine::CommOverlapType comm_type, TensorWrapper &aux_out,
+               cudaStream_t stream);
+};  // CommOverlapManager
 
 #endif  // TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_H_
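
The new CommOverlapManager hides whichever backend the build selects behind a single opaque context pointer. A self-contained sketch of that compile-time dispatch pattern, with stand-in types in place of transformer_engine::CommOverlapCore, CommGemmCtx, and nvte_comm_gemm_ctx_destroy (an illustration of the pattern only, not the real implementation):

    #include <cstdio>

    // In the real header the macro comes from the build system; leave it
    // undefined here to exercise the userbuffers-style path.
    // #define NVTE_WITH_CUBLASMP

    #ifndef NVTE_WITH_CUBLASMP
    struct OverlapCore {                    // stand-in for CommOverlapCore
      bool is_fp8_ubuf() const { return false; }
    };
    #else
    struct CommGemmCtxStub;                 // stand-in for the cuBLASMp context
    CommGemmCtxStub *ctx_create();
    void ctx_destroy(CommGemmCtxStub *);
    #endif

    class Manager {
    #ifndef NVTE_WITH_CUBLASMP
      OverlapCore *_ctx;
    #else
      CommGemmCtxStub *_ctx;
    #endif

     public:
    #ifndef NVTE_WITH_CUBLASMP
      Manager() : _ctx(new OverlapCore()) {}
    #else
      Manager() : _ctx(ctx_create()) {}
    #endif

      ~Manager() {
    #ifdef NVTE_WITH_CUBLASMP
        ctx_destroy(_ctx);                  // C-style teardown, as in the header above
    #else
        delete _ctx;                        // plain C++ ownership otherwise
    #endif
      }

      bool is_fp8_ubuf() const {
    #ifndef NVTE_WITH_CUBLASMP
        return _ctx->is_fp8_ubuf();
    #else
        return false;                       // cuBLASMp path reports no FP8 ubuf, as above
    #endif
      }
    };

    int main() {
      Manager m;
      std::printf("fp8 ubuf: %d\n", m.is_fp8_ubuf() ? 1 : 0);
      return 0;
    }
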