2424#include < cuda_runtime.h>
2525#include " ibuffered_device.h"
2626#include " logger/logger.h"
27+ #include < cufile.h>
28+ #include < mutex>
29+ #include < fcntl.h>
30+ #include < unistd.h>
31+ #include < cerrno>
32+ #include < cstring>
33+ #include < unordered_map>
34+ #include < cstdlib>
35+ #include " infra/template/handle_recorder.h"
2736
2837#define CUDA_TRANS_UNIT_SIZE (sizeof (uint64_t ) * 2 )
2938#define CUDA_TRANS_BLOCK_NUMBER (32 )
@@ -90,6 +99,25 @@ struct fmt::formatter<cudaError_t> : formatter<int32_t> {
9099
91100namespace UC {
92101
102+ static Status CreateCuFileHandle (int fd, CUfileHandle_t& cuFileHandle)
103+ {
104+ if (fd < 0 ) {
105+ UC_ERROR (" Invalid file descriptor: {}" , fd);
106+ return Status::Error ();
107+ }
108+
109+ CUfileDescr_t cfDescr{};
110+ cfDescr.handle .fd = fd;
111+ cfDescr.type = CU_FILE_HANDLE_TYPE_OPAQUE_FD;
112+ CUfileError_t err = cuFileHandleRegister (&cuFileHandle, &cfDescr);
113+ if (err.err != CU_FILE_SUCCESS) {
114+ UC_ERROR (" Failed to register cuFile handle for fd {}: error {}" ,
115+ fd, static_cast <int >(err.err ));
116+ return Status::Error ();
117+ }
118+
119+ return Status::OK ();
120+ }
93121template <typename Api, typename ... Args>
94122Status CudaApi (const char * caller, const char * file, const size_t line, const char * name, Api&& api,
95123 Args&&... args)
@@ -133,12 +161,23 @@ class CudaDevice : public IBufferedDevice {
133161 return nullptr ;
134162 }
135163 static void ReleaseDeviceArray (void * deviceArray) { CUDA_API (cudaFree, deviceArray); }
164+ static std::once_flag gdsOnce_;
136165
137166public:
167+ static Status InitGdsOnce ();
138168 CudaDevice (const int32_t deviceId, const size_t bufferSize, const size_t bufferNumber)
139169 : IBufferedDevice{deviceId, bufferSize, bufferNumber}, stream_{nullptr }
140170 {
141171 }
172+ ~CudaDevice () {
173+ HandlePool<int , CUfileHandle_t>::Instance ().ClearAll ([](CUfileHandle_t h) {
174+ cuFileHandleDeregister (h);
175+ });
176+
177+ if (stream_ != nullptr ) {
178+ cudaStreamDestroy ((cudaStream_t)stream_);
179+ }
180+ }
142181 Status Setup () override
143182 {
144183 auto status = Status::OK ();
@@ -165,6 +204,52 @@ public:
165204 {
166205 return CUDA_API (cudaMemcpyAsync, dst, src, count, cudaMemcpyDeviceToHost, this ->stream_ );
167206 }
207+ Status S2DSync (int fd, void * address, const size_t length, const size_t fileOffset, const size_t devOffset) override
208+ {
209+ CUfileHandle_t cuFileHandle = nullptr ;
210+ auto status = HandlePool<int , CUfileHandle_t>::Instance ().Get (fd, cuFileHandle,
211+ [fd](CUfileHandle_t& handle) -> Status {
212+ return CreateCuFileHandle (fd, handle);
213+ });
214+ if (status.Failure ()) {
215+ return status;
216+ }
217+ ssize_t bytesRead = cuFileRead (cuFileHandle, address, length, fileOffset, devOffset);
218+ HandlePool<int , CUfileHandle_t>::Instance ().Put (fd, [](CUfileHandle_t h) {
219+ if (h != nullptr ) {
220+ cuFileHandleDeregister (h);
221+ }
222+ });
223+
224+ if (bytesRead < 0 || (size_t )bytesRead != length) {
225+ UC_ERROR (" cuFileRead failed for fd {}: expected {}, got {}" , fd, length, bytesRead);
226+ return Status::Error ();
227+ }
228+ return Status::OK ();
229+ }
230+ Status D2SSync (int fd, void * address, const size_t length, const size_t fileOffset, const size_t devOffset) override
231+ {
232+ CUfileHandle_t cuFileHandle = nullptr ;
233+ auto status = HandlePool<int , CUfileHandle_t>::Instance ().Get (fd, cuFileHandle,
234+ [fd](CUfileHandle_t& handle) -> Status {
235+ return CreateCuFileHandle (fd, handle);
236+ });
237+ if (status.Failure ()) {
238+ return status;
239+ }
240+ ssize_t bytesWrite = cuFileWrite (cuFileHandle, address, length, fileOffset, devOffset);
241+ HandlePool<int , CUfileHandle_t>::Instance ().Put (fd, [](CUfileHandle_t h) {
242+ if (h != nullptr ) {
243+ cuFileHandleDeregister (h);
244+ }
245+ });
246+
247+ if (bytesWrite < 0 || (size_t )bytesWrite != length) {
248+ UC_ERROR (" cuFileWrite failed for fd {}: expected {}, got {}" , fd, length, bytesWrite);
249+ return Status::Error ();
250+ }
251+ return Status::OK ();
252+ }
168253 Status AppendCallback (std::function<void (bool )> cb) override
169254 {
170255 auto * c = new (std::nothrow) Closure (cb);
@@ -226,6 +311,14 @@ private:
226311 cudaStream_t stream_;
227312};
228313
314+ Status DeviceFactory::Setup (bool useDirect)
315+ {
316+ if (useDirect) {
317+ return CudaDevice::InitGdsOnce ();
318+ }
319+ return Status::OK ();
320+ }
321+
229322std::unique_ptr<IDevice> DeviceFactory::Make (const int32_t deviceId, const size_t bufferSize,
230323 const size_t bufferNumber)
231324{
@@ -237,5 +330,20 @@ std::unique_ptr<IDevice> DeviceFactory::Make(const int32_t deviceId, const size_
237330 return nullptr ;
238331 }
239332}
333+ std::once_flag CudaDevice::gdsOnce_{};
334+ Status CudaDevice::InitGdsOnce ()
335+ {
336+ Status result = Status::OK ();
337+ std::call_once (gdsOnce_, [&result]() {
338+ CUfileError_t ret = cuFileDriverOpen ();
339+ if (ret.err == CU_FILE_SUCCESS) {
340+ UC_INFO (" GDS driver initialized successfully" );
341+ } else {
342+ UC_ERROR (" GDS driver initialization failed with error code: {}" , static_cast <int >(ret.err ));
343+ result = Status::Error ();
344+ }
345+ });
346+ return result;
347+ }
240348
241349} // namespace UC
0 commit comments