2424#include < cuda_runtime.h>
2525#include " ibuffered_device.h"
2626#include " logger/logger.h"
27+ #include < cufile.h>
28+ #include < mutex>
29+ #include < fcntl.h>
30+ #include < unistd.h>
31+ #include < cerrno>
32+ #include < cstring>
33+ #include " sharded_handle_recorder.h"
2734
2835#define CUDA_TRANS_UNIT_SIZE (sizeof (uint64_t ) * 2 )
2936#define CUDA_TRANS_BLOCK_NUMBER (32 )
@@ -90,6 +97,28 @@ struct fmt::formatter<cudaError_t> : formatter<int32_t> {
9097
9198namespace UC {
9299
100+ static Status CreateCuFileHandle (const std::string& path, int flags, CUfileHandle_t& cuFileHandle, int & fd)
101+ {
102+ fd = open (path.c_str (), flags, 0644 );
103+ if (fd < 0 ) {
104+ UC_ERROR (" Failed to open file {}: {}" , path, strerror (errno));
105+ return Status::Error ();
106+ }
107+
108+ CUfileDescr_t cfDescr{};
109+ cfDescr.handle .fd = fd;
110+ cfDescr.type = CU_FILE_HANDLE_TYPE_OPAQUE_FD;
111+ CUfileError_t err = cuFileHandleRegister (&cuFileHandle, &cfDescr);
112+ if (err.err != CU_FILE_SUCCESS) {
113+ UC_ERROR (" Failed to register cuFile handle for {}: error {}" ,
114+ path, static_cast <int >(err.err ));
115+ close (fd);
116+ fd = -1 ;
117+ return Status::Error ();
118+ }
119+
120+ return Status::OK ();
121+ }
93122template <typename Api, typename ... Args>
94123Status CudaApi (const char * caller, const char * file, const size_t line, const char * name, Api&& api,
95124 Args&&... args)
@@ -133,17 +162,32 @@ class CudaDevice : public IBufferedDevice {
133162 return nullptr ;
134163 }
135164 static void ReleaseDeviceArray (void * deviceArray) { CUDA_API (cudaFree, deviceArray); }
165+ static std::once_flag gdsOnce_;
166+ static void InitGdsOnce ();
136167
137168public:
138169 CudaDevice (const int32_t deviceId, const size_t bufferSize, const size_t bufferNumber)
139170 : IBufferedDevice{deviceId, bufferSize, bufferNumber}, stream_{nullptr }
140171 {
141172 }
142- Status Setup () override
173+ ~CudaDevice () {
174+ CuFileHandleRecorder::Instance ().ClearAll ([](CUfileHandle_t h, int fd) {
175+ cuFileHandleDeregister (h);
176+ if (fd >= 0 ) {
177+ close (fd);
178+ }
179+ });
180+
181+ if (stream_ != nullptr ) {
182+ cudaStreamDestroy ((cudaStream_t)stream_);
183+ }
184+ }
185+ Status Setup (bool transferUseDirect) override
143186 {
187+ if (transferUseDirect) {InitGdsOnce ();}
144188 auto status = Status::OK ();
145189 if ((status = CUDA_API (cudaSetDevice, this ->deviceId )).Failure ()) { return status; }
146- if ((status = IBufferedDevice::Setup ()).Failure ()) { return status; }
190+ if ((status = IBufferedDevice::Setup (transferUseDirect )).Failure ()) { return status; }
147191 if ((status = CUDA_API (cudaStreamCreate, (cudaStream_t*)&this ->stream_ )).Failure ()) {
148192 return status;
149193 }
@@ -165,6 +209,40 @@ public:
165209 {
166210 return CUDA_API (cudaMemcpyAsync, dst, src, count, cudaMemcpyDeviceToHost, this ->stream_ );
167211 }
212+ Status S2DSync (const std::string& path, void * address, const size_t length, const size_t fileOffset, const size_t devOffset) override
213+ {
214+ CUfileHandle_t cuFileHandle = nullptr ;
215+ auto status = CuFileHandleRecorder::Instance ().Get (path, cuFileHandle,
216+ [&path](CUfileHandle_t& handle, int & fd) -> Status {
217+ return CreateCuFileHandle (path, O_RDONLY | O_DIRECT, handle, fd);
218+ });
219+ if (status.Failure ()) {
220+ return status;
221+ }
222+ ssize_t bytesRead = cuFileRead (cuFileHandle, address, length, fileOffset, devOffset);
223+ if (bytesRead < 0 || (size_t )bytesRead != length) {
224+ UC_ERROR (" cuFileRead failed for {}: expected {}, got {}" , path, length, bytesRead);
225+ return Status::Error ();
226+ }
227+ return Status::OK ();
228+ }
229+ Status D2SSync (const std::string& path, void * address, const size_t length, const size_t fileOffset, const size_t devOffset) override
230+ {
231+ CUfileHandle_t cuFileHandle = nullptr ;
232+ auto status = CuFileHandleRecorder::Instance ().Get (path, cuFileHandle,
233+ [&path](CUfileHandle_t& handle, int & fd) -> Status {
234+ return CreateCuFileHandle (path, O_WRONLY | O_CREAT | O_DIRECT, handle, fd);
235+ });
236+ if (status.Failure ()) {
237+ return status;
238+ }
239+ ssize_t bytesWrite = cuFileWrite (cuFileHandle, address, length, fileOffset, devOffset);
240+ if (bytesWrite < 0 || (size_t )bytesWrite != length) {
241+ UC_ERROR (" cuFileWrite failed for {}: expected {}, got {}" , path, length, bytesWrite);
242+ return Status::Error ();
243+ }
244+ return Status::OK ();
245+ }
168246 Status AppendCallback (std::function<void (bool )> cb) override
169247 {
170248 auto * c = new (std::nothrow) Closure (cb);
@@ -237,5 +315,17 @@ std::unique_ptr<IDevice> DeviceFactory::Make(const int32_t deviceId, const size_
237315 return nullptr ;
238316 }
239317}
318+ std::once_flag CudaDevice::gdsOnce_{};
319+ void CudaDevice::InitGdsOnce ()
320+ {
321+ std::call_once (gdsOnce_, [] (){
322+ CUfileError_t ret = cuFileDriverOpen ();
323+ if (ret.err == CU_FILE_SUCCESS) {
324+ UC_INFO (" GDS driver initialized successfully" );
325+ } else {
326+ UC_ERROR (" GDS driver initialized unsuccessfully" );
327+ }
328+ });
329+ }
240330
241331} // namespace UC
0 commit comments