diff --git a/include/mscclpp/core.hpp b/include/mscclpp/core.hpp index d3af3103..7357bf05 100644 --- a/include/mscclpp/core.hpp +++ b/include/mscclpp/core.hpp @@ -377,6 +377,18 @@ class Connection { virtual void write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, uint64_t size) = 0; + /// Write data from a source @ref RegisteredMemory to a destination @ref RegisteredMemory in a 2D fashion. + /// + /// @param dst The destination @ref RegisteredMemory. + /// @param dstOffset The offset in bytes from the start of the destination @ref RegisteredMemory. + /// @param dstPitch The pitch of the destination @ref RegisteredMemory in bytes. + /// @param src The source @ref RegisteredMemory. + /// @param srcOffset The offset in bytes from the start of the source @ref RegisteredMemory. + /// @param srcPitch The pitch of the source @ref RegisteredMemory in bytes. + /// @param width The width of the 2D region to write in bytes. + /// @param height The height of the 2D region. + virtual void write2D(RegisteredMemory dst, uint64_t dstOffset, uint64_t dstPitch, RegisteredMemory src, + uint64_t srcOffset, uint64_t srcPitch, uint64_t width, uint64_t height) = 0; /// Update a 8-byte value in a destination @ref RegisteredMemory and synchronize the change with the remote process. /// /// @param dst The destination @ref RegisteredMemory. diff --git a/include/mscclpp/proxy_channel.hpp b/include/mscclpp/proxy_channel.hpp index 2c644648..67479224 100644 --- a/include/mscclpp/proxy_channel.hpp +++ b/include/mscclpp/proxy_channel.hpp @@ -29,10 +29,18 @@ class ProxyService : public BaseProxyService { ProxyService(); /// Build and add a semaphore to the proxy service. + /// @param communicator The communicator for bootstrapping. /// @param connection The connection associated with the semaphore. /// @return The ID of the semaphore. 
SemaphoreId buildAndAddSemaphore(Communicator& communicator, std::shared_ptr connection); + /// Build and add a semaphore with pitch to the proxy service. This is used for 2D transfers. + /// @param communicator The communicator for bootstrapping. + /// @param connection The connection associated with the channel. + /// @param pitch The pitch pair. + SemaphoreId buildAndAddSemaphore(Communicator& communicator, std::shared_ptr connection, + std::pair pitch); + /// Add a semaphore to the proxy service. /// @param semaphore The semaphore to be added /// @return The ID of the semaphore. @@ -62,6 +70,7 @@ class ProxyService : public BaseProxyService { private: std::vector> semaphores_; std::vector memories_; + std::vector> pitches_; Proxy proxy_; int deviceNumaNode; diff --git a/include/mscclpp/proxy_channel_device.hpp b/include/mscclpp/proxy_channel_device.hpp index db90eac7..23b696a7 100644 --- a/include/mscclpp/proxy_channel_device.hpp +++ b/include/mscclpp/proxy_channel_device.hpp @@ -27,6 +27,10 @@ const TriggerType TriggerSync = 0x4; // Trigger a flush. #define MSCCLPP_BITS_CONNID 10 #define MSCCLPP_BITS_FIFO_RESERVED 1 +#define MSCCLPP_BITS_WIDTH_SIZE 16 +#define MSCCLPP_BITS_HEIGHT_SIZE 16 +#define MSCCLPP_2D_FLAG 1 + /// Basic structure of each work element in the FIFO. 
union ChannelTrigger { ProxyTrigger value; @@ -47,6 +51,25 @@ union ChannelTrigger { uint64_t reserved : MSCCLPP_BITS_FIFO_RESERVED; } fields; + struct { + // First 64 bits: value[0] + uint64_t width : MSCCLPP_BITS_WIDTH_SIZE; + uint64_t height : MSCCLPP_BITS_HEIGHT_SIZE; + uint64_t srcOffset : MSCCLPP_BITS_OFFSET; + uint64_t + : (64 - MSCCLPP_BITS_WIDTH_SIZE - MSCCLPP_BITS_HEIGHT_SIZE - MSCCLPP_BITS_OFFSET); // ensure 64-bit alignment + // Second 64 bits: value[1] + uint64_t dstOffset : MSCCLPP_BITS_OFFSET; + uint64_t srcMemoryId : MSCCLPP_BITS_REGMEM_HANDLE; + uint64_t dstMemoryId : MSCCLPP_BITS_REGMEM_HANDLE; + uint64_t type : MSCCLPP_BITS_TYPE; + uint64_t chanId : MSCCLPP_BITS_CONNID; + uint64_t multiDimensionFlag : MSCCLPP_2D_FLAG; + uint64_t : (64 - MSCCLPP_BITS_OFFSET - MSCCLPP_BITS_REGMEM_HANDLE - MSCCLPP_BITS_REGMEM_HANDLE - MSCCLPP_BITS_TYPE - + MSCCLPP_BITS_CONNID - MSCCLPP_2D_FLAG - MSCCLPP_BITS_FIFO_RESERVED); // ensure 64-bit alignment + uint64_t reserved : MSCCLPP_BITS_FIFO_RESERVED; + } fields2D; + #ifdef __CUDACC__ /// Default constructor. __forceinline__ __device__ ChannelTrigger() {} @@ -71,6 +94,27 @@ union ChannelTrigger { << MSCCLPP_BITS_OFFSET) + dstOffset); } + + /// Constructor. + /// @param type The type of the trigger. + /// @param dst The destination memory region. + /// @param dstOffset The offset into the destination memory region. + /// @param src The source memory region. + /// @param srcOffset The offset into the source memory region. + /// @param width The width of the 2D region. + /// @param height The height of the 2D region. + /// @param semaphoreId The ID of the semaphore. 
+ __device__ ChannelTrigger(TriggerType type, MemoryId dst, uint64_t dstOffset, MemoryId src, uint64_t srcOffset, + uint64_t width, uint64_t height, int semaphoreId) { + value.fst = (((srcOffset << MSCCLPP_BITS_HEIGHT_SIZE) + height) << MSCCLPP_BITS_WIDTH_SIZE) + width; + value.snd = ((((((((((1ULL << MSCCLPP_BITS_CONNID) + semaphoreId) << MSCCLPP_BITS_TYPE) + type) + << MSCCLPP_BITS_REGMEM_HANDLE) + + dst) + << MSCCLPP_BITS_REGMEM_HANDLE) + + src) + << MSCCLPP_BITS_OFFSET) + + dstOffset); + } #endif // __CUDACC__ }; @@ -104,6 +148,28 @@ struct ProxyChannelDeviceHandle { put(dst, offset, src, offset, size); } + /// @brief Push a @ref TriggerData to the FIFO. + /// @param dst The destination memory region. + /// @param dstOffset The offset into the destination memory region. + /// @param src The source memory region. + /// @param srcOffset The offset into the source memory region. + /// @param width The width of the 2D region. + /// @param height The height of the 2D region. + __forceinline__ __device__ void put2D(MemoryId dst, uint64_t dstOffset, MemoryId src, uint64_t srcOffset, + uint32_t width, uint32_t height) { + fifo_.push(ChannelTrigger(TriggerData, dst, dstOffset, src, srcOffset, width, height, semaphoreId_).value); + } + + /// @brief Push a @ref TriggerData to the FIFO. + /// @param dst The destination memory region. + /// @param src The source memory region. + /// @param offset The common offset into the destination and source memory regions. + /// @param width The width of the 2D region. + /// @param height The height of the 2D region. + __forceinline__ __device__ void put2D(MemoryId dst, MemoryId src, uint64_t offset, uint32_t width, uint32_t height) { + put2D(dst, offset, src, offset, width, height); + } + /// Push a @ref TriggerFlag to the FIFO. 
__forceinline__ __device__ void signal() { fifo_.push(ChannelTrigger(TriggerFlag, 0, 0, 0, 0, 1, semaphoreId_).value); @@ -120,6 +186,19 @@ struct ProxyChannelDeviceHandle { fifo_.push(ChannelTrigger(TriggerData | TriggerFlag, dst, dstOffset, src, srcOffset, size, semaphoreId_).value); } + /// Push a @ref TriggerData and a @ref TriggerFlag at the same time to the FIFO. + /// @param dst The destination memory region. + /// @param dstOffset The offset into the destination memory region. + /// @param src The source memory region. + /// @param srcOffset The offset into the source memory region. + /// @param width The width of the 2D region. + /// @param height The height of the 2D region. + __forceinline__ __device__ void put2DWithSignal(MemoryId dst, uint64_t dstOffset, MemoryId src, uint64_t srcOffset, + uint32_t width, uint32_t height) { + fifo_.push( + ChannelTrigger(TriggerData | TriggerFlag, dst, dstOffset, src, srcOffset, width, height, semaphoreId_).value); + } + /// Push a @ref TriggerData and a @ref TriggerFlag at the same time to the FIFO. /// @param dst The destination memory region. /// @param src The source memory region. @@ -129,6 +208,17 @@ struct ProxyChannelDeviceHandle { putWithSignal(dst, offset, src, offset, size); } + /// Push a @ref TriggerData and a @ref TriggerFlag at the same time to the FIFO. + /// @param dst The destination memory region. + /// @param src The source memory region. + /// @param offset The common offset into the destination and source memory regions. + /// @param width The width of the 2D region. + /// @param height The height of the 2D region. + __forceinline__ __device__ void put2DWithSignal(MemoryId dst, MemoryId src, uint64_t offset, uint32_t width, + uint32_t height) { + put2DWithSignal(dst, offset, src, offset, width, height); + } + /// Push a @ref TriggerData, a @ref TriggerFlag, and a @ref TriggerSync at the same time to the FIFO. /// @param dst The destination memory region. 
/// @param dstOffset The offset into the destination memory region. @@ -178,6 +268,15 @@ struct SimpleProxyChannelDeviceHandle { proxyChan_.put(dst_, dstOffset, src_, srcOffset, size); } + /// Push a @ref TriggerData to the FIFO. + /// @param dstOffset The offset into the destination memory region. + /// @param srcOffset The offset into the source memory region. + /// @param width The width of the 2D region. + /// @param height The height of the 2D region. + __forceinline__ __device__ void put2D(uint64_t dstOffset, uint64_t srcOffset, uint32_t width, uint32_t height) { + proxyChan_.put2D(dst_, dstOffset, src_, srcOffset, width, height); + } + /// Push a @ref TriggerData to the FIFO. /// @param offset The common offset into the destination and source memory regions. /// @param size The size of the transfer. @@ -194,11 +293,29 @@ struct SimpleProxyChannelDeviceHandle { proxyChan_.putWithSignal(dst_, dstOffset, src_, srcOffset, size); } + /// Push a @ref TriggerData and a @ref TriggerFlag at the same time to the FIFO. + /// @param dstOffset The offset into the destination memory region. + /// @param srcOffset The offset into the source memory region. + /// @param width The width of the 2D region. + /// @param height The height of the 2D region. + __forceinline__ __device__ void put2DWithSignal(uint64_t dstOffset, uint64_t srcOffset, uint32_t width, + uint32_t height) { + proxyChan_.put2DWithSignal(dst_, dstOffset, src_, srcOffset, width, height); + } + /// Push a @ref TriggerData and a @ref TriggerFlag at the same time to the FIFO. /// @param offset The common offset into the destination and source memory regions. /// @param size The size of the transfer. __forceinline__ __device__ void putWithSignal(uint64_t offset, uint64_t size) { putWithSignal(offset, offset, size); } + /// Push a @ref TriggerData and a @ref TriggerFlag at the same time to the FIFO. + /// @param offset The common offset into the destination and source memory regions. 
+ /// @param width The width of the 2D region. + /// @param height The height of the 2D region. + __forceinline__ __device__ void put2DWithSignal(uint64_t offset, uint32_t width, uint32_t height) { + put2DWithSignal(offset, offset, width, height); + } + /// Push a @ref TriggerData, a @ref TriggerFlag, and a @ref TriggerSync at the same time to the FIFO. /// @param dstOffset The offset into the destination memory region. /// @param srcOffset The offset into the source memory region. diff --git a/python/proxy_channel_py.cpp b/python/proxy_channel_py.cpp index a483f99d..4249dd80 100644 --- a/python/proxy_channel_py.cpp +++ b/python/proxy_channel_py.cpp @@ -2,6 +2,7 @@ // Licensed under the MIT license. #include +#include #include #include @@ -19,7 +20,13 @@ void register_proxy_channel(nb::module_& m) { .def(nb::init<>()) .def("start_proxy", &ProxyService::startProxy) .def("stop_proxy", &ProxyService::stopProxy) - .def("build_and_add_semaphore", &ProxyService::buildAndAddSemaphore, nb::arg("comm"), nb::arg("connection")) + .def("build_and_add_semaphore", + nb::overload_cast>(&ProxyService::buildAndAddSemaphore), + nb::arg("comm"), nb::arg("connection")) + .def("build_and_add_semaphore", + nb::overload_cast, std::pair>( + &ProxyService::buildAndAddSemaphore), + nb::arg("comm"), nb::arg("connection"), nb::arg("pitch")) .def("add_semaphore", &ProxyService::addSemaphore, nb::arg("semaphore")) .def("add_memory", &ProxyService::addMemory, nb::arg("memory")) .def("semaphore", &ProxyService::semaphore, nb::arg("id")) diff --git a/src/connection.cc b/src/connection.cc index 112e1178..6820ad33 100644 --- a/src/connection.cc +++ b/src/connection.cc @@ -57,6 +57,20 @@ void CudaIpcConnection::write(RegisteredMemory dst, uint64_t dstOffset, Register // npkitCollectEntryEvent(conn, NPKIT_EVENT_DMA_SEND_DATA_ENTRY, (uint32_t)size); } +void CudaIpcConnection::write2D(RegisteredMemory dst, uint64_t dstOffset, uint64_t dstPitch, RegisteredMemory src, + uint64_t srcOffset, uint64_t 
srcPitch, uint64_t width, uint64_t height) { + validateTransport(dst, remoteTransport()); + validateTransport(src, transport()); + + char* dstPtr = (char*)dst.data(); + char* srcPtr = (char*)src.data(); + + MSCCLPP_CUDATHROW(cudaMemcpy2DAsync(dstPtr + dstOffset, dstPitch, srcPtr + srcOffset, srcPitch, width, height, + cudaMemcpyDeviceToDevice, stream_)); + INFO(MSCCLPP_P2P, "CudaIpcConnection write: from %p to %p, width %lu height %lu dstPitch %lu srcPitch %lu", + srcPtr + srcOffset, dstPtr + dstOffset, width, height, dstPitch, srcPitch); +} + void CudaIpcConnection::updateAndSync(RegisteredMemory dst, uint64_t dstOffset, uint64_t* src, uint64_t newValue) { validateTransport(dst, remoteTransport()); uint64_t oldValue = *src; @@ -131,6 +145,11 @@ void IBConnection::write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMem // npkitCollectEntryEvent(conn, NPKIT_EVENT_IB_SEND_DATA_ENTRY, (uint32_t)size); } +void IBConnection::write2D(RegisteredMemory, uint64_t, uint64_t, RegisteredMemory, uint64_t, uint64_t, uint64_t, + uint64_t) { + throw Error("write2D is not supported", ErrorCode::InvalidUsage); +} + void IBConnection::updateAndSync(RegisteredMemory dst, uint64_t dstOffset, uint64_t* src, uint64_t newValue) { validateTransport(dst, remoteTransport()); auto dstTransportInfo = getRegisteredMemoryImpl(dst)->getTransportInfo(remoteTransport()); diff --git a/src/include/connection.hpp b/src/include/connection.hpp index 0475691c..5ecec672 100644 --- a/src/include/connection.hpp +++ b/src/include/connection.hpp @@ -41,6 +41,8 @@ class CudaIpcConnection : public ConnectionBase { void write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, uint64_t size) override; + void write2D(RegisteredMemory dst, uint64_t dstOffset, uint64_t dstPitch, RegisteredMemory src, uint64_t srcOffset, + uint64_t srcPitch, uint64_t width, uint64_t height) override; void updateAndSync(RegisteredMemory dst, uint64_t dstOffset, uint64_t* src, uint64_t newValue) 
override; void flush(int64_t timeoutUsec) override; @@ -65,6 +67,8 @@ class IBConnection : public ConnectionBase { void write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, uint64_t size) override; + void write2D(RegisteredMemory dst, uint64_t dstOffset, uint64_t dstPitch, RegisteredMemory src, uint64_t srcOffset, + uint64_t srcPitch, uint64_t width, uint64_t height) override; void updateAndSync(RegisteredMemory dst, uint64_t dstOffset, uint64_t* src, uint64_t newValue) override; void flush(int64_t timeoutUsec) override; diff --git a/src/proxy_channel.cc b/src/proxy_channel.cc index c6a4e243..cfe6862b 100644 --- a/src/proxy_channel.cc +++ b/src/proxy_channel.cc @@ -29,6 +29,16 @@ MSCCLPP_API_CPP SemaphoreId ProxyService::buildAndAddSemaphore(Communicator& com return semaphores_.size() - 1; } +MSCCLPP_API_CPP SemaphoreId ProxyService::buildAndAddSemaphore(Communicator& communicator, + std::shared_ptr connection, + std::pair pitch) { + semaphores_.push_back(std::make_shared(communicator, connection)); + SemaphoreId id = semaphores_.size() - 1; + if (id >= pitches_.size()) pitches_.resize(id + 1, std::pair(0, 0)); + pitches_[id] = pitch; + return id; +} + MSCCLPP_API_CPP SemaphoreId ProxyService::addSemaphore(std::shared_ptr semaphore) { semaphores_.push_back(semaphore); return semaphores_.size() - 1; @@ -67,8 +77,14 @@ ProxyHandlerResult ProxyService::handleTrigger(ProxyTrigger triggerRaw) { if (trigger->fields.type & TriggerData) { RegisteredMemory& dst = memories_[trigger->fields.dstMemoryId]; RegisteredMemory& src = memories_[trigger->fields.srcMemoryId]; - semaphore->connection()->write(dst, trigger->fields.dstOffset, src, trigger->fields.srcOffset, - trigger->fields.size); + if (trigger->fields2D.multiDimensionFlag) { + std::pair& pitch = pitches_.at(trigger->fields.chanId); + semaphore->connection()->write2D(dst, trigger->fields.dstOffset, pitch.first, src, trigger->fields.srcOffset, + pitch.second, trigger->fields2D.width, 
trigger->fields2D.height); + } else { + semaphore->connection()->write(dst, trigger->fields.dstOffset, src, trigger->fields.srcOffset, + trigger->fields.size); + } } if (trigger->fields.type & TriggerFlag) { diff --git a/test/mp_unit/communicator_tests.cu b/test/mp_unit/communicator_tests.cu index 829403b9..4ce14cc8 100644 --- a/test/mp_unit/communicator_tests.cu +++ b/test/mp_unit/communicator_tests.cu @@ -3,6 +3,7 @@ #include +#include #include #include @@ -143,6 +144,21 @@ void CommunicatorTest::writeToRemote(int dataCountPerRank) { } } +void CommunicatorTest::writeTileToRemote(size_t rowIndex, size_t colIndex, size_t pitch, size_t width, size_t height) { + size_t offset = rowIndex * pitch + colIndex * sizeof(int); + for (size_t n = 0; n < numBuffers; n++) { + for (int i = 0; i < gEnv->worldSize; i++) { + if (i != gEnv->rank) { + auto& conn = connections.at(i); + auto& peerMemory = remoteMemory[n].at(i); + conn->write2D(peerMemory, offset, deviceBufferPitchSize, localMemory[n], offset, deviceBufferPitchSize, + width * sizeof(int), height); + conn->flush(); + } + } + } +} + bool CommunicatorTest::testWriteCorrectness(bool skipLocal) { size_t dataCount = deviceBufferSize / sizeof(int); for (int n = 0; n < (int)devicePtr.size(); n++) { @@ -184,6 +200,45 @@ TEST_F(CommunicatorTest, BasicWrite) { communicator->bootstrap()->barrier(); } +TEST_F(CommunicatorTest, TileWrite) { + if (gEnv->rank >= numRanksToUse) return; + if (gEnv->worldSize > gEnv->nRanksPerNode) { + // tile write only supports a single node + GTEST_SKIP(); + } + deviceBufferInit(); + communicator->bootstrap()->barrier(); + + size_t dataSizePerRank = deviceBufferSize / gEnv->worldSize; + size_t rowCountPerRank = dataSizePerRank / deviceBufferPitchSize; + size_t colCount = deviceBufferPitchSize / sizeof(int); + // The per-rank region is rowCountPerRank rows by colCount columns. We split it into multiple small tiles. 
+ std::array, 3> nTileInDimension = {std::pair{2, 2}, {4, 4}, {8, 8}}; + for (auto& nTile : nTileInDimension) { + const int nRowPerTile = rowCountPerRank / nTile.first; + const int nColPerTile = colCount / nTile.second; + for (int xi = 0; xi < nTile.first; ++xi) { + for (int yi = 0; yi < nTile.second; ++yi) { + writeTileToRemote(rowCountPerRank * gEnv->rank + xi * nRowPerTile, yi * nColPerTile, deviceBufferPitchSize, + colCount / nTile.second, rowCountPerRank / nTile.first); + } + } + } + communicator->bootstrap()->barrier(); + + // polling until it becomes ready + bool ready = false; + int niter = 0; + do { + ready = testWriteCorrectness(); + niter++; + if (niter == 10000) { + FAIL() << "Polling is stuck."; + } + } while (!ready); + communicator->bootstrap()->barrier(); +} + __global__ void kernelWaitSemaphores(mscclpp::Host2DeviceSemaphore::DeviceHandle* deviceSemaphores, int rank, int worldSize) { int tid = threadIdx.x; diff --git a/test/mp_unit/mp_unit_tests.hpp b/test/mp_unit/mp_unit_tests.hpp index 39325563..67f31724 100644 --- a/test/mp_unit/mp_unit_tests.hpp +++ b/test/mp_unit/mp_unit_tests.hpp @@ -11,6 +11,7 @@ #include #include #include +#include #include "ib.hpp" @@ -115,10 +116,12 @@ class CommunicatorTest : public CommunicatorTestBase { void deviceBufferInit(); void writeToRemote(int dataCountPerRank); + void writeTileToRemote(size_t rowIndex, size_t colIndex, size_t pitch, size_t width, size_t height); bool testWriteCorrectness(bool skipLocal = false); const size_t numBuffers = 10; const int deviceBufferSize = 1024 * 1024; + const int deviceBufferPitchSize = 512; std::vector> devicePtr; std::vector localMemory; std::vector> remoteMemory; @@ -134,6 +137,8 @@ class ProxyChannelOneToOneTest : public CommunicatorTestBase { void setupMeshConnections(std::vector& proxyChannels, bool useIbOnly, void* sendBuff, size_t sendBuffBytes, void* recvBuff = nullptr, size_t recvBuffBytes = 0); + void setupMeshConnections(std::vector& proxyChannels, bool useIbOnly, 
void* sendBuff, + size_t sendBuffBytes, size_t pitchSize, void* recvBuff = nullptr, size_t recvBuffBytes = 0); void testPacketPingPong(bool useIbOnly); void testPacketPingPongPerf(bool useIbOnly); diff --git a/test/mp_unit/proxy_channel_tests.cu b/test/mp_unit/proxy_channel_tests.cu index 5537fe01..20a1069e 100644 --- a/test/mp_unit/proxy_channel_tests.cu +++ b/test/mp_unit/proxy_channel_tests.cu @@ -17,6 +17,12 @@ void ProxyChannelOneToOneTest::TearDown() { CommunicatorTestBase::TearDown(); } void ProxyChannelOneToOneTest::setupMeshConnections(std::vector& proxyChannels, bool useIbOnly, void* sendBuff, size_t sendBuffBytes, void* recvBuff, size_t recvBuffBytes) { + setupMeshConnections(proxyChannels, useIbOnly, sendBuff, sendBuffBytes, sendBuffBytes, recvBuff, recvBuffBytes); +} + +void ProxyChannelOneToOneTest::setupMeshConnections(std::vector& proxyChannels, + bool useIbOnly, void* sendBuff, size_t sendBuffBytes, size_t pitch, + void* recvBuff, size_t recvBuffBytes) { const int rank = communicator->bootstrap()->getRank(); const int worldSize = communicator->bootstrap()->getNranks(); const bool isInPlace = (recvBuff == nullptr); @@ -49,7 +55,12 @@ void ProxyChannelOneToOneTest::setupMeshConnections(std::vectorsetup(); - mscclpp::SemaphoreId cid = proxyService->buildAndAddSemaphore(*communicator, conn); + mscclpp::SemaphoreId cid; + if (sendBuffBytes == pitch) { + cid = proxyService->buildAndAddSemaphore(*communicator, conn); + } else { + cid = proxyService->buildAndAddSemaphore(*communicator, conn, std::pair(pitch, pitch)); + } communicator->setup(); proxyChannels.emplace_back(proxyService->proxyChannel(cid), proxyService->addMemory(remoteMemory.get()), @@ -59,6 +70,72 @@ void ProxyChannelOneToOneTest::setupMeshConnections(std::vector gChannelOneToOneTestConstProxyChans; +__device__ size_t getTileElementOffset(int elementId, int width, int rowIndex, int colIndex, int nElemPerPitch) { + int rowIndexInTile = elementId / width; + int colIndexInTile = elementId % 
width; + return (rowIndex + rowIndexInTile) * nElemPerPitch + (colIndex + colIndexInTile); +} + +__global__ void kernelProxyTilePingPong(int* buff, int rank, int pitch, int rowIndex, int colIndex, int width, + int height, int* ret) { + DeviceHandle& proxyChan = gChannelOneToOneTestConstProxyChans; + volatile int* sendBuff = (volatile int*)buff; + int nTries = 1000; + int flusher = 0; + size_t offset = rowIndex * pitch + colIndex * sizeof(int); + size_t nElem = width * height; + size_t nElemPerPitch = pitch / sizeof(int); + for (int i = 0; i < nTries; i++) { + if (rank == 0) { + if (i > 0) { + if (threadIdx.x == 0) proxyChan.wait(); + __syncthreads(); + for (int j = threadIdx.x; j < nElem; j += blockDim.x) { + size_t tileOffset = getTileElementOffset(j, width, rowIndex, colIndex, nElemPerPitch); + if (sendBuff[tileOffset] != offset + i - 1 + j) { + // printf("rank 0 ERROR: sendBuff[%d] = %d, expected %d\n", j, sendBuff[j], rank1Offset + i - 1 + j); + *ret = 1; + break; + } + } + } + for (int j = threadIdx.x; j < nElem; j += blockDim.x) { + size_t tileOffset = getTileElementOffset(j, width, rowIndex, colIndex, nElemPerPitch); + sendBuff[tileOffset] = i + j; + } + __syncthreads(); + // __threadfence_system(); // not necessary if we make sendBuff volatile + if (threadIdx.x == 0) proxyChan.put2DWithSignal(offset, width * sizeof(int), height); + } + if (rank == 1) { + if (threadIdx.x == 0) proxyChan.wait(); + __syncthreads(); + for (int j = threadIdx.x; j < nElem; j += blockDim.x) { + size_t tileOffset = getTileElementOffset(j, width, rowIndex, colIndex, nElemPerPitch); + if (sendBuff[tileOffset] != i + j) { + // printf("rank 1 ERROR: sendBuff[%d] = %d, expected %d\n", j, sendBuff[j], i + j); + *ret = 1; + break; + } + } + if (i < nTries - 1) { + for (int j = threadIdx.x; j < nElem; j += blockDim.x) { + size_t tileOffset = getTileElementOffset(j, width, rowIndex, colIndex, nElemPerPitch); + sendBuff[tileOffset] = offset + i + j; + } + __syncthreads(); + // 
__threadfence_system(); // not necessary if we make sendBuff volatile + if (threadIdx.x == 0) proxyChan.put2DWithSignal(offset, width * sizeof(int), height); + } + } + flusher++; + if (flusher == 100) { + if (threadIdx.x == 0) proxyChan.flush(); + flusher = 0; + } + } +} + __global__ void kernelProxyPingPong(int* buff, int rank, int nElem, int* ret) { DeviceHandle& proxyChan = gChannelOneToOneTestConstProxyChans; volatile int* sendBuff = (volatile int*)buff; @@ -155,6 +232,54 @@ TEST_F(ProxyChannelOneToOneTest, PingPongIb) { proxyService->stopProxy(); } +TEST_F(ProxyChannelOneToOneTest, PingPongTile) { + if (gEnv->rank >= numRanksToUse) return; + if (gEnv->worldSize > gEnv->nRanksPerNode) { + // tile write only support single node + GTEST_SKIP(); + } + + const int nElem = 4 * 1024 * 1024; + + std::vector proxyChannels; + std::shared_ptr buff = mscclpp::allocSharedCuda(nElem); + const int pitchSize = 512; // the buff tile is 8192x128 + setupMeshConnections(proxyChannels, false, buff.get(), nElem * sizeof(int), pitchSize); + + ASSERT_EQ(proxyChannels.size(), 1); + MSCCLPP_CUDATHROW(cudaMemcpyToSymbol(gChannelOneToOneTestConstProxyChans, proxyChannels.data(), + sizeof(DeviceHandle))); + + proxyService->startProxy(); + + std::shared_ptr ret = mscclpp::makeSharedCudaHost(0); + + kernelProxyTilePingPong<<<1, 1024>>>(buff.get(), gEnv->rank, pitchSize, 0, 0, 1, 1, ret.get()); + MSCCLPP_CUDATHROW(cudaDeviceSynchronize()); + + EXPECT_EQ(*ret, 0); + + kernelProxyTilePingPong<<<1, 1024>>>(buff.get(), gEnv->rank, pitchSize, 128, 32, 64, 64, ret.get()); + MSCCLPP_CUDATHROW(cudaDeviceSynchronize()); + + EXPECT_EQ(*ret, 0); + + kernelProxyTilePingPong<<<1, 1024>>>(buff.get(), gEnv->rank, pitchSize, 16, 16, 1, 8192, ret.get()); + MSCCLPP_CUDATHROW(cudaDeviceSynchronize()); + + EXPECT_EQ(*ret, 0); + + kernelProxyTilePingPong<<<1, 1024>>>(buff.get(), gEnv->rank, pitchSize, 5, 0, 128, 1, ret.get()); + MSCCLPP_CUDATHROW(cudaDeviceSynchronize()); + + EXPECT_EQ(*ret, 0); + + 
kernelProxyTilePingPong<<<1, 1024>>>(buff.get(), gEnv->rank, pitchSize, 0, 0, 128, 8192, ret.get()); + MSCCLPP_CUDATHROW(cudaDeviceSynchronize()); + + EXPECT_EQ(*ret, 0); +} + __device__ mscclpp::DeviceSyncer gChannelOneToOneTestProxyChansSyncer; template