Skip to content

Commit b014608

Browse files
committed
refine
1 parent c56c527 commit b014608

File tree

4 files changed

+6
-4
lines changed

4 files changed

+6
-4
lines changed

oneflow/user/kernels/collective_communication/cpu/cpu_recv.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ class CpuRecvImpl final : public Recv {
4141
}
4242

4343
void Launch(ep::Stream* stream, void* out, size_t elem_cnt, int64_t src,
44-
const CclComm& ccl_comm) const override {
44+
CclComm ccl_comm) const override {
4545
Launch(stream, out, elem_cnt, src);
4646
}
4747

oneflow/user/kernels/collective_communication/cuda/cuda_send.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ class CudaSend final : public Send {
3434
void Launch(ep::Stream* stream, const void* in, size_t elem_cnt, int64_t dst) const override {
3535
#if HAS_NCCL_SEND_RECV
3636
const auto& comm_and_peer_rank = GetNcclCommAndPeerNcclRank(dst);
37+
printf("\n CudaSend >>> Launch >>> communication_ctx");
3738
OF_NCCL_CHECK(ncclSend(in, elem_cnt, nccl_datatype_, comm_and_peer_rank.second,
3839
comm_and_peer_rank.first, stream->As<ep::CudaStream>()->cuda_stream()));
3940
#else
@@ -42,14 +43,15 @@ class CudaSend final : public Send {
4243
}
4344

4445
void Launch(ep::Stream* stream, const void* in, size_t elem_cnt, int64_t dst,
45-
CclComm ccl_comm) const override {
46+
ccl::CclComm ccl_comm) const override {
4647
#if HAS_NCCL_SEND_RECV
4748
ncclComm_t* comm = reinterpret_cast<ncclComm_t*>(ccl_comm.getComm());
4849
OF_NCCL_CHECK(ncclSend(in, elem_cnt, nccl_datatype_, dst, *comm,
4950
stream->As<ep::CudaStream>()->cuda_stream()));
5051
#else
5152
UNIMPLEMENTED() << "GPU send is only supported when nccl version >= 2.7"
5253
#endif // HAS_NCCL_SEND_RECV
54+
printf("\n CudaSend >>> Launch >>> ccl::CclComm");
5355
}
5456

5557
private:

oneflow/user/kernels/collective_communication/include/recv.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ class Recv : public CollectiveCommunication {
3333
virtual void Launch(ep::Stream* stream, void* out, size_t elem_cnt, int64_t src) const = 0;
3434

3535
virtual void Launch(ep::Stream* stream, void* out, size_t elem_cnt, int64_t src,
36-
const CclComm& ccl_comm) const = 0;
36+
CclComm ccl_comm) const = 0;
3737
};
3838

3939
inline bool IsRecvRegistered(DeviceType device_type) {

oneflow/user/kernels/nccl_logical_2d_sbp_kernels.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ class NcclLogical2DSameDim0KernelCommState : public user_op::OpKernelState {
4242
}
4343
~NcclLogical2DSameDim0KernelCommState() override = default;
4444

45-
const ccl::CclComm& ccl_comm() const {
45+
ccl::CclComm ccl_comm() {
4646
if (!is_init_) { Init(); }
4747
return ccl_comm_;
4848
}

0 commit comments

Comments
 (0)