Skip to content

Commit baa2e53

Browse files
committed
GDS
1 parent c4eb386 commit baa2e53

File tree

16 files changed

+563
-12
lines changed

16 files changed

+563
-12
lines changed

ucm/store/device/cuda/CMakeLists.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,9 @@ target_compile_options(storedevice PRIVATE
88
--diag-suppress=128 --diag-suppress=2417 --diag-suppress=2597
99
-Wall -fPIC
1010
)
11+
# Imported target for the CUDA runtime.
#
# A target has a single IMPORTED_LOCATION; listing the property twice in one
# set_target_properties() call keeps only the last value, which previously made
# Cuda::cudart resolve to libcufile.so and dropped libcudart.so from the link.
# libcufile (GPUDirect Storage) is therefore attached as an interface link
# dependency, so everything that links Cuda::cudart still gets both libraries.
add_library(Cuda::cudart UNKNOWN IMPORTED)
set_target_properties(Cuda::cudart PROPERTIES
    INTERFACE_INCLUDE_DIRECTORIES "${CUDA_ROOT}/include"
    IMPORTED_LOCATION "${CUDA_ROOT}/lib64/libcudart.so"
    INTERFACE_LINK_LIBRARIES "${CUDA_ROOT}/lib64/libcufile.so"
)

ucm/store/device/cuda/cuda_device.cu

Lines changed: 92 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,13 @@
2424
#include <cuda_runtime.h>
2525
#include "ibuffered_device.h"
2626
#include "logger/logger.h"
27+
#include <cufile.h>
28+
#include <mutex>
29+
#include <fcntl.h>
30+
#include <unistd.h>
31+
#include <cerrno>
32+
#include <cstring>
33+
#include "sharded_handle_recorder.h"
2734

2835
#define CUDA_TRANS_UNIT_SIZE (sizeof(uint64_t) * 2)
2936
#define CUDA_TRANS_BLOCK_NUMBER (32)
@@ -90,6 +97,28 @@ struct fmt::formatter<cudaError_t> : formatter<int32_t> {
9097

9198
namespace UC {
9299

100+
/**
 * Opens `path` and registers the resulting descriptor with cuFile.
 *
 * @param path          File to open.
 * @param flags         Flags forwarded to open(2); new files are created with mode 0644.
 * @param cuFileHandle  Receives the registered cuFile handle on success.
 * @param fd            Receives the open descriptor on success. It holds the failed
 *                      open() result if the open fails, and is reset to -1 (after the
 *                      descriptor is closed) if registration fails.
 * @return Status::OK() on success, Status::Error() if either step fails.
 */
static Status CreateCuFileHandle(const std::string& path, int flags, CUfileHandle_t& cuFileHandle, int& fd)
{
    fd = open(path.c_str(), flags, 0644);
    if (fd < 0) {
        UC_ERROR("Failed to open file {}: {}", path, strerror(errno));
        return Status::Error();
    }

    CUfileDescr_t descriptor{};
    descriptor.handle.fd = fd;
    descriptor.type = CU_FILE_HANDLE_TYPE_OPAQUE_FD;
    const CUfileError_t registration = cuFileHandleRegister(&cuFileHandle, &descriptor);
    if (registration.err == CU_FILE_SUCCESS) { return Status::OK(); }

    UC_ERROR("Failed to register cuFile handle for {}: error {}",
             path, static_cast<int>(registration.err));
    // Registration failed, so nothing else owns the descriptor: release it here.
    close(fd);
    fd = -1;
    return Status::Error();
}
93122
template <typename Api, typename... Args>
94123
Status CudaApi(const char* caller, const char* file, const size_t line, const char* name, Api&& api,
95124
Args&&... args)
@@ -133,17 +162,32 @@ class CudaDevice : public IBufferedDevice {
133162
return nullptr;
134163
}
135164
static void ReleaseDeviceArray(void* deviceArray) { CUDA_API(cudaFree, deviceArray); }
165+
static std::once_flag gdsOnce_;
166+
static void InitGdsOnce();
136167

137168
public:
138169
CudaDevice(const int32_t deviceId, const size_t bufferSize, const size_t bufferNumber)
139170
: IBufferedDevice{deviceId, bufferSize, bufferNumber}, stream_{nullptr}
140171
{
141172
}
142-
Status Setup() override
173+
/**
 * Releases the cuFile handles held by the recorder and destroys this device's stream.
 *
 * NOTE(review): CuFileHandleRecorder::Instance() is a single shared object and
 * ClearAll() visits every entry in it, so destroying one CudaDevice also
 * deregisters handles that other CudaDevice instances may still be using.
 * Confirm that only one device is alive per process, or scope the recorder per device.
 */
~CudaDevice()
{
    auto releaseEntry = [](CUfileHandle_t registered, int descriptor) {
        cuFileHandleDeregister(registered);
        if (descriptor < 0) { return; }
        close(descriptor);
    };
    CuFileHandleRecorder::Instance().ClearAll(releaseEntry);

    if (stream_ == nullptr) { return; }
    cudaStreamDestroy((cudaStream_t)stream_);
}
185+
Status Setup(bool transferUseDirect) override
143186
{
187+
if(transferUseDirect) {InitGdsOnce();}
144188
auto status = Status::OK();
145189
if ((status = CUDA_API(cudaSetDevice, this->deviceId)).Failure()) { return status; }
146-
if ((status = IBufferedDevice::Setup()).Failure()) { return status; }
190+
if ((status = IBufferedDevice::Setup(transferUseDirect)).Failure()) { return status; }
147191
if ((status = CUDA_API(cudaStreamCreate, (cudaStream_t*)&this->stream_)).Failure()) {
148192
return status;
149193
}
@@ -165,6 +209,40 @@ public:
165209
{
166210
return CUDA_API(cudaMemcpyAsync, dst, src, count, cudaMemcpyDeviceToHost, this->stream_);
167211
}
212+
/**
 * Reads `length` bytes from the file at `path` with cuFileRead.
 *
 * @param path        File to read; opened with O_RDONLY | O_DIRECT the first time it is used.
 * @param address     Forwarded to cuFileRead as the buffer base
 *                    (presumably a device pointer — confirm against callers).
 * @param length      Number of bytes requested; a short read is treated as a failure.
 * @param fileOffset  Forwarded to cuFileRead as the file offset.
 * @param devOffset   Forwarded to cuFileRead as the offset from `address`.
 * @return Status::OK() only when exactly `length` bytes were read.
 */
Status S2DSync(const std::string& path, void* address, const size_t length, const size_t fileOffset, const size_t devOffset) override
{
    // The recorder keeps one handle per key. Keying by the bare path would let a
    // handle that was opened write-only for the same file be handed back here,
    // and reading through it fails. The access mode is therefore part of the key.
    const std::string cacheKey = "rd:" + path;
    CUfileHandle_t cuFileHandle = nullptr;
    auto status = CuFileHandleRecorder::Instance().Get(cacheKey, cuFileHandle,
        [&path](CUfileHandle_t& handle, int& fd) -> Status {
            return CreateCuFileHandle(path, O_RDONLY | O_DIRECT, handle, fd);
        });
    if (status.Failure()) {
        return status;
    }
    const ssize_t bytesRead = cuFileRead(cuFileHandle, address, length, fileOffset, devOffset);
    if (bytesRead < 0 || static_cast<size_t>(bytesRead) != length) {
        UC_ERROR("cuFileRead failed for {}: expected {}, got {}", path, length, bytesRead);
        return Status::Error();
    }
    return Status::OK();
}
229+
/**
 * Writes `length` bytes to the file at `path` with cuFileWrite.
 *
 * @param path        File to write; opened with O_WRONLY | O_CREAT | O_DIRECT the
 *                    first time it is used.
 * @param address     Forwarded to cuFileWrite as the buffer base
 *                    (presumably a device pointer — confirm against callers).
 * @param length      Number of bytes to write; a short write is treated as a failure.
 * @param fileOffset  Forwarded to cuFileWrite as the file offset.
 * @param devOffset   Forwarded to cuFileWrite as the offset from `address`.
 * @return Status::OK() only when exactly `length` bytes were written.
 */
Status D2SSync(const std::string& path, void* address, const size_t length, const size_t fileOffset, const size_t devOffset) override
{
    // The recorder keeps one handle per key. Keying by the bare path would let a
    // handle that was opened read-only for the same file be handed back here,
    // and writing through it fails. The access mode is therefore part of the key.
    const std::string cacheKey = "wr:" + path;
    CUfileHandle_t cuFileHandle = nullptr;
    auto status = CuFileHandleRecorder::Instance().Get(cacheKey, cuFileHandle,
        [&path](CUfileHandle_t& handle, int& fd) -> Status {
            return CreateCuFileHandle(path, O_WRONLY | O_CREAT | O_DIRECT, handle, fd);
        });
    if (status.Failure()) {
        return status;
    }
    const ssize_t bytesWrite = cuFileWrite(cuFileHandle, address, length, fileOffset, devOffset);
    if (bytesWrite < 0 || static_cast<size_t>(bytesWrite) != length) {
        UC_ERROR("cuFileWrite failed for {}: expected {}, got {}", path, length, bytesWrite);
        return Status::Error();
    }
    return Status::OK();
}
168246
Status AppendCallback(std::function<void(bool)> cb) override
169247
{
170248
auto* c = new (std::nothrow) Closure(cb);
@@ -237,5 +315,17 @@ std::unique_ptr<IDevice> DeviceFactory::Make(const int32_t deviceId, const size_
237315
return nullptr;
238316
}
239317
}
318+
std::once_flag CudaDevice::gdsOnce_{};
319+
void CudaDevice::InitGdsOnce()
320+
{
321+
std::call_once(gdsOnce_, [] (){
322+
CUfileError_t ret = cuFileDriverOpen();
323+
if (ret.err == CU_FILE_SUCCESS) {
324+
UC_INFO("GDS driver initialized successfully");
325+
} else {
326+
UC_ERROR("GDS driver initialized unsuccessfully");
327+
}
328+
});
329+
}
240330

241331
} // namespace UC
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
#ifndef UC_INFRA_SHARDED_HANDLE_RECORDER_H
2+
#define UC_INFRA_SHARDED_HANDLE_RECORDER_H
3+
4+
#include <cstdint>
#include <functional>
#include <string>
#include <unistd.h>
#include <cufile.h>
#include "infra/template/hashmap.h"
#include "status/status.h"
9+
10+
namespace UC {
11+
12+
class CuFileHandleRecorder {
13+
private:
14+
struct RecordValue {
15+
CUfileHandle_t handle;
16+
int fd;
17+
uint64_t refCount;
18+
};
19+
using HandleMap = HashMap<std::string, RecordValue, std::hash<std::string>, 10>;
20+
HandleMap handles_;
21+
CuFileHandleRecorder() = default;
22+
CuFileHandleRecorder(const CuFileHandleRecorder&) = delete;
23+
CuFileHandleRecorder& operator=(const CuFileHandleRecorder&) = delete;
24+
25+
public:
26+
static CuFileHandleRecorder& Instance()
27+
{
28+
static CuFileHandleRecorder recorder;
29+
return recorder;
30+
}
31+
32+
Status Get(const std::string& path, CUfileHandle_t& handle,
33+
std::function<Status(CUfileHandle_t&, int&)> instantiate)
34+
{
35+
auto result = handles_.GetOrCreate(path, [&instantiate](RecordValue& value) -> bool {
36+
int fd = -1;
37+
CUfileHandle_t h = nullptr;
38+
39+
auto status = instantiate(h, fd);
40+
if (status.Failure()) {
41+
return false;
42+
}
43+
44+
value.handle = h;
45+
value.fd = fd;
46+
value.refCount = 1;
47+
return true;
48+
});
49+
50+
if (!result.has_value()) {
51+
return Status::Error();
52+
}
53+
54+
auto& recordValue = result.value().get();
55+
recordValue.refCount++;
56+
handle = recordValue.handle;
57+
return Status::OK();
58+
}
59+
60+
void Put(const std::string& path,
61+
std::function<void(CUfileHandle_t)> cleanup)
62+
{
63+
handles_.Upsert(path, [&cleanup](RecordValue& value) -> bool {
64+
value.refCount--;
65+
if (value.refCount > 0) {
66+
return false;
67+
}
68+
cleanup(value.handle);
69+
return true;
70+
});
71+
}
72+
73+
void ClearAll(std::function<void(CUfileHandle_t, int)> cleanup)
74+
{
75+
handles_.ForEach([&cleanup](const std::string& path, RecordValue& value) {
76+
cleanup(value.handle, value.fd);
77+
});
78+
handles_.Clear();
79+
}
80+
};
81+
82+
} // namespace UC
83+
84+
#endif // UC_INFRA_SHARDED_HANDLE_RECORDER_H

ucm/store/device/ibuffered_device.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,9 @@ class IBufferedDevice : public IDevice {
3535
: IDevice{deviceId, bufferSize, bufferNumber}
3636
{
3737
}
38-
Status Setup() override
38+
Status Setup(bool transferUseDirect) override
3939
{
40+
if(transferUseDirect) {return Status::OK();}
4041
auto totalSize = this->bufferSize * this->bufferNumber;
4142
if (totalSize == 0) { return Status::OK(); }
4243
this->_addr = this->MakeBuffer(totalSize);

ucm/store/device/idevice.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ class IDevice {
3737
{
3838
}
3939
virtual ~IDevice() = default;
40-
virtual Status Setup() = 0;
40+
virtual Status Setup(bool transferUseDirect) = 0;
4141
virtual std::shared_ptr<std::byte> GetBuffer(const size_t size) = 0;
4242
virtual Status H2DSync(std::byte* dst, const std::byte* src, const size_t count) = 0;
4343
virtual Status D2HSync(std::byte* dst, const std::byte* src, const size_t count) = 0;
@@ -49,6 +49,8 @@ class IDevice {
4949
const size_t count) = 0;
5050
virtual Status D2HBatchSync(std::byte* hArr[], const std::byte* dArr[], const size_t number,
5151
const size_t count) = 0;
52+
virtual Status S2DSync(const std::string& path, void* address, const size_t length, const size_t fileOffset, const size_t devOffset) = 0;
53+
virtual Status D2SSync(const std::string& path, void* address, const size_t length, const size_t fileOffset, const size_t devOffset) = 0;
5254

5355
protected:
5456
virtual std::shared_ptr<std::byte> MakeBuffer(const size_t size) = 0;

0 commit comments

Comments
 (0)